easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
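
The hunk below appears to be the full content of easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py — the only file in the list whose +3768 line count matches the hunk header — a vendored copy of TensorFlow's metrics_impl.py adapted for distributed training.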
@@ -0,0 +1,3768 @@
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Implementation of tf.metrics module."""
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ from tensorflow.python.distribute import distribution_strategy_context
+ from tensorflow.python.eager import context
+ from tensorflow.python.framework import dtypes
+ from tensorflow.python.framework import ops
+ from tensorflow.python.framework import sparse_tensor
+ from tensorflow.python.ops import array_ops
+ from tensorflow.python.ops import check_ops
+ from tensorflow.python.ops import confusion_matrix
+ from tensorflow.python.ops import control_flow_ops
+ from tensorflow.python.ops import math_ops
+ from tensorflow.python.ops import nn
+ from tensorflow.python.ops import sets
+ from tensorflow.python.ops import sparse_ops
+ from tensorflow.python.ops import state_ops
+ from tensorflow.python.ops import variable_scope
+ from tensorflow.python.ops import weights_broadcast_ops
+ from tensorflow.python.platform import tf_logging as logging
+ from tensorflow.python.util.deprecation import deprecated
+ from tensorflow.python.util.tf_export import tf_export
+
+ # from tensorflow.python.training import distribution_strategy_context
+
+
+ def metric_variable(shape, dtype, validate_shape=True, name=None):
+   """Create variable in `GraphKeys.(GLOBAL|METRIC_VARIABLES)` collections.
+
+   If running in a `DistributionStrategy` context, the variable will be
+   "sync on read". This means:
+
+   * The returned object will be a container with separate variables
+     per replica of the model.
+
+   * When writing to the variable, e.g. using `assign_add` in a metric
+     update, the update will be applied to the variable local to the
+     replica.
+
+   * To get a metric's result value, we need to sum the variable values
+     across the replicas before computing the final answer. Furthermore,
+     the final answer should be computed once instead of in every
+     replica. Both of these are accomplished by running the computation
+     of the final result value inside
+     `distribution_strategy_context.get_replica_context().merge_call(fn)`.
+     Inside the `merge_call()`, ops are only added to the graph once
+     and access to a sync on read variable in a computation returns
+     the sum across all replicas.
+
+   Args:
+     shape: Shape of the created variable.
+     dtype: Type of the created variable.
+     validate_shape: (Optional) Whether shape validation is enabled for
+       the created variable.
+     name: (Optional) String name of the created variable.
+
+   Returns:
+     A (non-trainable) variable initialized to zero, or if inside a
+     `DistributionStrategy` scope a sync on read variable container.
+   """
+   # Note that synchronization "ON_READ" implies trainable=False.
+   return variable_scope.variable(
+       lambda: array_ops.zeros(shape, dtype),
+       trainable=False,
+       collections=[
+           ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+       ],
+       validate_shape=validate_shape,
+       synchronization=variable_scope.VariableSynchronization.ON_READ,
+       aggregation=variable_scope.VariableAggregation.SUM,
+       name=name)
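
A minimal sketch of what this builds, assuming a plain TF1 graph with no distribution strategy in scope. Note that this vendored copy registers the variable in GLOBAL_VARIABLES (stock TensorFlow uses LOCAL_VARIABLES), so it is covered by the global initializer rather than the local one:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Hand-rolled equivalent of metric_variable([], tf.float32, name='total');
# illustrative only, not package code.
total = tf.Variable(
    initial_value=tf.zeros([], tf.float32),
    trainable=False,
    collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.METRIC_VARIABLES],
    name='total')
update = tf.assign_add(total, 3.0)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update)
  print(sess.run(total))  # 3.0; METRIC_VARIABLES makes the state easy to reset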
+
+
+ def _remove_squeezable_dimensions(predictions, labels, weights):
+   """Squeeze or expand last dim if needed.
+
+   Squeezes last dim of `predictions` or `labels` if their rank differs by 1
+   (using confusion_matrix.remove_squeezable_dimensions).
+   Squeezes or expands last dim of `weights` if its rank differs by 1 from the
+   new rank of `predictions`.
+
+   If `weights` is scalar, it is kept scalar.
+
+   This will use static shape if available. Otherwise, it will add graph
+   operations, which could result in a performance hit.
+
+   Args:
+     predictions: Predicted values, a `Tensor` of arbitrary dimensions.
+     labels: Optional label `Tensor` whose dimensions match `predictions`.
+     weights: Optional weight scalar or `Tensor` whose dimensions match
+       `predictions`.
+
+   Returns:
+     Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
+     the last dimension squeezed, `weights` could be extended by one dimension.
+   """
+   predictions = ops.convert_to_tensor(predictions)
+   if labels is not None:
+     labels, predictions = confusion_matrix.remove_squeezable_dimensions(
+         labels, predictions)
+     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+   if weights is None:
+     return predictions, labels, None
+
+   weights = ops.convert_to_tensor(weights)
+   weights_shape = weights.get_shape()
+   weights_rank = weights_shape.ndims
+   if weights_rank == 0:
+     return predictions, labels, weights
+
+   predictions_shape = predictions.get_shape()
+   predictions_rank = predictions_shape.ndims
+   if (predictions_rank is not None) and (weights_rank is not None):
+     # Use static rank.
+     if weights_rank - predictions_rank == 1:
+       weights = array_ops.squeeze(weights, [-1])
+     elif predictions_rank - weights_rank == 1:
+       weights = array_ops.expand_dims(weights, [-1])
+   else:
+     # Use dynamic rank.
+     weights_rank_tensor = array_ops.rank(weights)
+     rank_diff = weights_rank_tensor - array_ops.rank(predictions)
+
+     def _maybe_expand_weights():
+       return control_flow_ops.cond(
+           math_ops.equal(rank_diff, -1),
+           lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)
+
+     # Don't attempt squeeze if it will fail based on static check.
+     if ((weights_rank is not None) and
+         (not weights_shape.dims[-1].is_compatible_with(1))):
+       maybe_squeeze_weights = lambda: weights  # noqa: E731
+     else:
+       maybe_squeeze_weights = lambda: array_ops.squeeze(  # noqa: E731
+           weights, [-1])
+
+     def _maybe_adjust_weights():
+       return control_flow_ops.cond(
+           math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
+           _maybe_expand_weights)
+
+     # If weights are scalar, do nothing. Otherwise, try to add or remove a
+     # dimension to match predictions.
+     weights = control_flow_ops.cond(
+         math_ops.equal(weights_rank_tensor, 0), lambda: weights,
+         _maybe_adjust_weights)
+   return predictions, labels, weights
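
A shape walkthrough of the common cases (illustrative, assuming the usual [B] / [B, 1] metric inputs):

# predictions [B, 1], labels [B]    -> predictions squeezed to [B]
# weights scalar ()                 -> returned unchanged
# weights [B], predictions [B]      -> ranks already match, unchanged
# weights [B, 1], predictions [B]   -> weights squeezed to [B]
# weights [B], predictions [B, C]   -> weights expanded to [B, 1]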
+
+
+ def _maybe_expand_labels(labels, predictions):
+   """If necessary, expand `labels` along last dimension to match `predictions`.
+
+   Args:
+     labels: `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
+       num_labels=1, in which case the result is an expanded `labels` with shape
+       [D1, ... DN, 1].
+     predictions: `Tensor` with shape [D1, ... DN, num_classes].
+
+   Returns:
+     `labels` with the same rank as `predictions`.
+
+   Raises:
+     ValueError: if `labels` has invalid shape.
+   """
+   with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
+     labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
+
+     # If sparse, expand sparse shape.
+     if isinstance(labels, sparse_tensor.SparseTensor):
+       return control_flow_ops.cond(
+           math_ops.equal(
+               array_ops.rank(predictions),
+               array_ops.size(labels.dense_shape) + 1),
+           lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
+               labels,
+               shape=array_ops.concat((labels.dense_shape, (1,)), 0),
+               name=scope),
+           lambda: labels)
+
+     # Otherwise, try to use static shape.
+     labels_rank = labels.get_shape().ndims
+     if labels_rank is not None:
+       predictions_rank = predictions.get_shape().ndims
+       if predictions_rank is not None:
+         if predictions_rank == labels_rank:
+           return labels
+         if predictions_rank == labels_rank + 1:
+           return array_ops.expand_dims(labels, -1, name=scope)
+         raise ValueError(
+             'Unexpected labels shape %s for predictions shape %s.' %
+             (labels.get_shape(), predictions.get_shape()))
+
+     # Otherwise, use dynamic shape.
+     return control_flow_ops.cond(
+         math_ops.equal(array_ops.rank(predictions),
+                        array_ops.rank(labels) + 1),
+         lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
+
+
+ def _safe_scalar_div(numerator, denominator, name):
+   """Divides two values, returning 0 if the denominator is 0.
+
+   Args:
+     numerator: A scalar `float64` `Tensor`.
+     denominator: A scalar `float64` `Tensor`.
+     name: Name for the returned op.
+
+   Returns:
+     0 if `denominator` == 0, else `numerator` / `denominator`
+   """
+   numerator.get_shape().with_rank_at_most(1)
+   denominator.get_shape().with_rank_at_most(1)
+   return math_ops.div_no_nan(numerator, denominator, name=name)
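
The safety comes entirely from `div_no_nan`, which defines division by zero as 0. A quick check, assuming a TF version where `tf.math.div_no_nan` is available (1.15+):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.Session() as sess:
  print(sess.run(tf.math.div_no_nan(0.0, 0.0)))  # 0.0, not NaN or inf
  print(sess.run(tf.math.div_no_nan(3.0, 4.0)))  # 0.75, ordinary division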
+
+
+ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
+   """Calculate a streaming confusion matrix.
+
+   Calculates a confusion matrix. For estimation over a stream of data,
+   the function creates an `update_op` operation.
+
+   Args:
+     labels: A `Tensor` of ground truth labels with shape [batch size] and of
+       type `int32` or `int64`. The tensor will be flattened if its rank > 1.
+     predictions: A `Tensor` of prediction results for semantic labels, whose
+       shape is [batch size] and type `int32` or `int64`. The tensor will be
+       flattened if its rank > 1.
+     num_classes: The possible number of labels the prediction task can
+       have. This value must be provided, since a confusion matrix of
+       dimension = [num_classes, num_classes] will be allocated.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+
+   Returns:
+     total_cm: A `Tensor` representing the confusion matrix.
+     update_op: An operation that increments the confusion matrix.
+   """
+   # Local variable to accumulate the predictions in the confusion matrix.
+   total_cm = metric_variable([num_classes, num_classes],
+                              dtypes.float64,
+                              name='total_confusion_matrix')
+
+   # Cast the type to int64 required by confusion_matrix_ops.
+   predictions = math_ops.cast(predictions, dtypes.int64)
+   labels = math_ops.cast(labels, dtypes.int64)
+   num_classes = math_ops.cast(num_classes, dtypes.int64)
+
+   # Flatten the input if its rank > 1.
+   if predictions.get_shape().ndims > 1:
+     predictions = array_ops.reshape(predictions, [-1])
+
+   if labels.get_shape().ndims > 1:
+     labels = array_ops.reshape(labels, [-1])
+
+   if (weights is not None) and (weights.get_shape().ndims > 1):
+     weights = array_ops.reshape(weights, [-1])
+
+   # Accumulate the prediction to current confusion matrix.
+   current_cm = confusion_matrix.confusion_matrix(
+       labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
+   update_op = state_ops.assign_add(total_cm, current_cm, use_locking=True)
+   return total_cm, update_op
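
The accumulation pattern this helper implements, sketched with the public `tf.confusion_matrix` op (illustrative; the helper additionally flattens inputs and applies weights):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

labels = tf.placeholder(tf.int64, [None])
preds = tf.placeholder(tf.int64, [None])
total_cm = tf.Variable(tf.zeros([3, 3], tf.float64), trainable=False)
batch_cm = tf.confusion_matrix(labels, preds, num_classes=3, dtype=tf.float64)
update_op = tf.assign_add(total_cm, batch_cm)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op, {labels: [0, 1, 2], preds: [0, 2, 2]})
  sess.run(update_op, {labels: [1], preds: [1]})
  print(sess.run(total_cm))  # rows = true class, columns = predicted class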
+
+
+ def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
+   """Aggregate metric value across replicas."""
+
+   def fn(distribution, *a):
+     """Call `metric_value_fn` in the correct control flow context."""
+     if hasattr(distribution.extended, '_outer_control_flow_context'):
+       # If there was an outer context captured before this method was called,
+       # then we enter that context to create the metric value op. If the
+       # captured context is `None`, ops.control_dependencies(None) gives the
+       # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
+       # captured context.
+       # This special handling is needed because sometimes the metric is created
+       # inside a while_loop (and perhaps a TPU rewrite context). But we don't
+       # want the value op to be evaluated every step or on the TPU. So we
+       # create it outside so that it can be evaluated at the end on the host,
+       # once the update ops have been evaluated.
+
+       # pylint: disable=protected-access
+       if distribution.extended._outer_control_flow_context is None:
+         with ops.control_dependencies(None):
+           metric_value = metric_value_fn(distribution, *a)
+       else:
+         distribution.extended._outer_control_flow_context.Enter()
+         metric_value = metric_value_fn(distribution, *a)
+         distribution.extended._outer_control_flow_context.Exit()
+       # pylint: enable=protected-access
+     else:
+       metric_value = metric_value_fn(distribution, *a)
+     if metrics_collections:
+       ops.add_to_collections(metrics_collections, metric_value)
+     return metric_value
+
+   return distribution_strategy_context.get_replica_context().merge_call(
+       fn, args=args)
+
+
+ @tf_export(v1=['metrics.mean'])
+ def mean(values,
+          weights=None,
+          metrics_collections=None,
+          updates_collections=None,
+          name=None):
+   """Computes the (weighted) mean of the given values.
+
+   The `mean` function creates two local variables, `total` and `count`
+   that are used to compute the average of `values`. This average is ultimately
+   returned as `mean` which is an idempotent operation that simply divides
+   `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `mean`.
+   `update_op` increments `total` with the reduced sum of the product of `values`
+   and `weights`, and it increments `count` with the reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A `Tensor` of arbitrary dimensions.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that `mean`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op`
+       should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean: A `Tensor` representing the current mean, the value of `total` divided
+       by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean is not supported when eager execution '
+                        'is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean', (values, weights)):
+     values = math_ops.cast(values, dtypes.float32)
+
+     total = metric_variable([], dtypes.float32, name='total')
+     count = metric_variable([], dtypes.float32, name='count')
+
+     if weights is None:
+       num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
+     else:
+       values, _, weights = _remove_squeezable_dimensions(
+           predictions=values, labels=None, weights=weights)
+       weights = weights_broadcast_ops.broadcast_weights(
+           math_ops.cast(weights, dtypes.float32), values)
+       values = math_ops.multiply(values, weights)
+       num_values = math_ops.reduce_sum(weights)
+
+     update_total_op = state_ops.assign_add(
+         total, math_ops.reduce_sum(values), use_locking=True)
+     with ops.control_dependencies([values]):
+       update_count_op = state_ops.assign_add(
+           count, num_values, use_locking=True)
+
+     def compute_mean(_, t, c):
+       return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')
+
+     mean_t = _aggregate_across_replicas(metrics_collections, compute_mean,
+                                         total, count)
+     update_op = math_ops.div_no_nan(
+         update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_t, update_op
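
The value-op/update-op contract in practice, as a sketch against the public `tf.metrics.mean` (which, unlike this vendored copy, keeps its state in local variables):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

values = tf.placeholder(tf.float32, [None])
mean_t, update_op = tf.metrics.mean(values)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # initializes total and count
  sess.run(update_op, {values: [1.0, 2.0]})   # total=3, count=2
  sess.run(update_op, {values: [4.0]})        # total=7, count=3
  print(sess.run(mean_t))  # 7/3 ~= 2.333; reading mean_t does not mutate state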
+
+
+ @tf_export(v1=['metrics.accuracy'])
+ def accuracy(labels,
+              predictions,
+              weights=None,
+              metrics_collections=None,
+              updates_collections=None,
+              name=None):
+   """Calculates how often `predictions` matches `labels`.
+
+   The `accuracy` function creates two local variables, `total` and
+   `count` that are used to compute the frequency with which `predictions`
+   matches `labels`. This frequency is ultimately returned as `accuracy`: an
+   idempotent operation that simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `accuracy`.
+   Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
+   where the corresponding elements of `predictions` and `labels` match and 0.0
+   otherwise. Then `update_op` increments `total` with the reduced sum of the
+   product of `weights` and `is_correct`, and it increments `count` with the
+   reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose shape matches
+       `predictions`.
+     predictions: The predicted values, a `Tensor` of any shape.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `accuracy` should
+       be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     accuracy: A `Tensor` representing the accuracy, the value of `total` divided
+       by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `accuracy`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.accuracy is not supported when eager '
+                        'execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+   if labels.dtype != predictions.dtype:
+     predictions = math_ops.cast(predictions, labels.dtype)
+   is_correct = math_ops.cast(
+       math_ops.equal(predictions, labels), dtypes.float32)
+   return mean(is_correct, weights, metrics_collections, updates_collections,
+               name or 'accuracy')
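
A worked example of the weighting, assuming the public `tf.metrics.accuracy` behaves like this copy:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# is_correct = [1, 1, 0, 1]; the last example is masked by its zero weight,
# so accuracy = (1 + 1 + 0) / (1 + 1 + 1) = 2/3.
acc, update_op = tf.metrics.accuracy(
    labels=[1, 0, 1, 1],
    predictions=[1, 0, 0, 1],
    weights=[1.0, 1.0, 1.0, 0.0])

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(acc))  # 0.6666...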
+
+
+ def _confusion_matrix_at_thresholds(labels,
+                                     predictions,
+                                     thresholds,
+                                     weights=None,
+                                     includes=None):
+   """Computes true_positives, false_negatives, true_negatives, false_positives.
+
+   This function creates up to four local variables, `true_positives`,
+   `true_negatives`, `false_positives` and `false_negatives`.
+   `true_positive[i]` is defined as the total weight of values in `predictions`
+   above `thresholds[i]` whose corresponding entry in `labels` is `True`.
+   `false_negatives[i]` is defined as the total weight of values in `predictions`
+   at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
+   `true_negatives[i]` is defined as the total weight of values in `predictions`
+   at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
+   `false_positives[i]` is defined as the total weight of values in `predictions`
+   above `thresholds[i]` whose corresponding entry in `labels` is `False`.
+
+   For estimation of these metrics over a stream of data, for each metric the
+   function respectively creates an `update_op` operation that updates the
+   variable and returns its value.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
+       default to all four.
+
+   Returns:
+     values: Dict of variables of shape `[len(thresholds)]`. Keys are from
+       `includes`.
+     update_ops: Dict of operations that increments the `values`. Keys are from
+       `includes`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       `includes` contains invalid keys.
+   """
+   all_includes = ('tp', 'fn', 'tn', 'fp')
+   if includes is None:
+     includes = all_includes
+   else:
+     for include in includes:
+       if include not in all_includes:
+         raise ValueError('Invalid key: %s.' % include)
+
+   with ops.control_dependencies([
+       check_ops.assert_greater_equal(
+           predictions,
+           math_ops.cast(0.0, dtype=predictions.dtype),
+           message='predictions must be in [0, 1]'),
+       check_ops.assert_less_equal(
+           predictions,
+           math_ops.cast(1.0, dtype=predictions.dtype),
+           message='predictions must be in [0, 1]')
+   ]):
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtypes.float32),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+
+   num_thresholds = len(thresholds)
+
+   # Reshape predictions and labels.
+   predictions_2d = array_ops.reshape(predictions, [-1, 1])
+   labels_2d = array_ops.reshape(
+       math_ops.cast(labels, dtype=dtypes.bool), [1, -1])
+
+   # Use static shape if known.
+   num_predictions = predictions_2d.get_shape().as_list()[0]
+
+   # Otherwise use dynamic shape.
+   if num_predictions is None:
+     num_predictions = array_ops.shape(predictions_2d)[0]
+   thresh_tiled = array_ops.tile(
+       array_ops.expand_dims(array_ops.constant(thresholds), [1]),
+       array_ops.stack([1, num_predictions]))
+
+   # Tile the predictions after thresholding them across different thresholds.
+   pred_is_pos = math_ops.greater(
+       array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
+       thresh_tiled)
+   if ('fn' in includes) or ('tn' in includes):
+     pred_is_neg = math_ops.logical_not(pred_is_pos)
+
+   # Tile labels by number of thresholds
+   label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
+   if ('fp' in includes) or ('tn' in includes):
+     label_is_neg = math_ops.logical_not(label_is_pos)
+
+   if weights is not None:
+     weights = weights_broadcast_ops.broadcast_weights(
+         math_ops.cast(weights, dtypes.float32), predictions)
+     weights_tiled = array_ops.tile(
+         array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
+     thresh_tiled.get_shape().assert_is_compatible_with(
+         weights_tiled.get_shape())
+   else:
+     weights_tiled = None
+
+   values = {}
+   update_ops = {}
+
+   if 'tp' in includes:
+     true_p = metric_variable([num_thresholds],
+                              dtypes.float32,
+                              name='true_positives')
+     is_true_positive = math_ops.cast(
+         math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
+     if weights_tiled is not None:
+       is_true_positive *= weights_tiled
+     update_ops['tp'] = state_ops.assign_add(
+         true_p, math_ops.reduce_sum(is_true_positive, 1), use_locking=True)
+     values['tp'] = true_p
+
+   if 'fn' in includes:
+     false_n = metric_variable([num_thresholds],
+                               dtypes.float32,
+                               name='false_negatives')
+     is_false_negative = math_ops.cast(
+         math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
+     if weights_tiled is not None:
+       is_false_negative *= weights_tiled
+     update_ops['fn'] = state_ops.assign_add(
+         false_n, math_ops.reduce_sum(is_false_negative, 1), use_locking=True)
+     values['fn'] = false_n
+
+   if 'tn' in includes:
+     true_n = metric_variable([num_thresholds],
+                              dtypes.float32,
+                              name='true_negatives')
+     is_true_negative = math_ops.cast(
+         math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
+     if weights_tiled is not None:
+       is_true_negative *= weights_tiled
+     update_ops['tn'] = state_ops.assign_add(
+         true_n, math_ops.reduce_sum(is_true_negative, 1), use_locking=True)
+     values['tn'] = true_n
+
+   if 'fp' in includes:
+     false_p = metric_variable([num_thresholds],
+                               dtypes.float32,
+                               name='false_positives')
+     is_false_positive = math_ops.cast(
+         math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
+     if weights_tiled is not None:
+       is_false_positive *= weights_tiled
+     update_ops['fp'] = state_ops.assign_add(
+         false_p, math_ops.reduce_sum(is_false_positive, 1), use_locking=True)
+     values['fp'] = false_p
+
+   return values, update_ops
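
A worked pass through the tiling logic for one small batch (illustrative numbers):

# predictions = [0.1, 0.6, 0.9], labels = [False, True, True]
# thresholds  = [0.25, 0.75] -> pred_is_pos is a [2, 3] matrix:
#   threshold 0.25: [False, True, True ] -> tp=2, fp=0, fn=0, tn=1
#   threshold 0.75: [False, False, True] -> tp=1, fp=0, fn=1, tn=1
# After one run of the update_ops: values['tp'] == [2, 1],
# values['fn'] == [0, 1], values['tn'] == [1, 1], values['fp'] == [0, 0].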
+
+
+ def _aggregate_variable(v, collections):
+   f = lambda distribution, value: distribution.extended.read_var(  # noqa: E731
+       value)
+   return _aggregate_across_replicas(collections, f, v)
+
+
+ @tf_export(v1=['metrics.auc'])
+ def auc(labels,
+         predictions,
+         weights=None,
+         num_thresholds=200,
+         metrics_collections=None,
+         updates_collections=None,
+         curve='ROC',
+         name=None,
+         summation_method='trapezoidal',
+         thresholds=None):
+   """Computes the approximate AUC via a Riemann sum.
+
+   The `auc` function creates four local variables, `true_positives`,
+   `true_negatives`, `false_positives` and `false_negatives` that are used to
+   compute the AUC. To discretize the AUC curve, a linearly spaced set of
+   thresholds is used to compute pairs of recall and precision values. The area
+   under the ROC-curve is therefore computed using the height of the recall
+   values by the false positive rate, while the area under the PR-curve is
+   computed using the height of the precision values by the recall.
+
+   This value is ultimately returned as `auc`, an idempotent operation that
+   computes the area under a discretized curve of precision versus recall values
+   (computed using the aforementioned variables). The `num_thresholds` variable
+   controls the degree of discretization with larger numbers of thresholds more
+   closely approximating the true AUC. The quality of the approximation may vary
+   dramatically depending on `num_thresholds`.
+
+   For best results, `predictions` should be distributed approximately uniformly
+   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
+   approximation may be poor if this is not the case. Setting `summation_method`
+   to 'minoring' or 'majoring' can help quantify the error in the approximation
+   by providing lower or upper bound estimate of the AUC. The `thresholds`
+   parameter can be used to manually specify thresholds which split the
+   predictions more evenly.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `auc`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     num_thresholds: The number of thresholds to use when discretizing the roc
+       curve.
+     metrics_collections: An optional list of collections that `auc` should be
+       added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     curve: Specifies the name of the curve to be computed, 'ROC' [default] or
+       'PR' for the Precision-Recall-curve.
+     name: An optional variable_scope name.
+     summation_method: Specifies the Riemann summation method used
+       (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
+       applies the trapezoidal rule; 'careful_interpolation', a variant of it
+       differing only by a more correct interpolation scheme for PR-AUC -
+       interpolating (true/false) positives but not the ratio that is precision;
+       'minoring' that applies left summation for increasing intervals and right
+       summation for decreasing intervals; 'majoring' that does the opposite.
+       Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
+       (to be deprecated soon) as it applies the same method for ROC, and a
+       better one (see Davis & Goadrich 2006 for details) for the PR curve.
+     thresholds: An optional list of floating point values to use as the
+       thresholds for discretizing the curve. If set, the `num_thresholds`
+       parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
+       equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
+       automatically included with these to correctly handle predictions equal to
+       exactly 0 or 1.
+
+   Returns:
+     auc: A scalar `Tensor` representing the current area-under-curve.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables
+       appropriately and whose value matches `auc`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   print('use_tf_auc')
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.auc is not supported when eager execution '
+                        'is enabled.')
+
+   with variable_scope.variable_scope(name, 'auc',
+                                      (labels, predictions, weights)):
+     if curve != 'ROC' and curve != 'PR':
+       raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
+
+     kepsilon = 1e-7  # To account for floating point imprecisions.
+     if thresholds is not None:
+       # If specified, use the supplied thresholds.
+       thresholds = sorted(thresholds)
+       num_thresholds = len(thresholds) + 2
+     else:
+       # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
+       # (0, 1).
+       thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+                     for i in range(num_thresholds - 2)]
+
+     # Add an endpoint "threshold" below zero and above one for either threshold
+     # method.
+     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
+
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights)
+
+     # Add epsilons to avoid dividing by 0.
+     epsilon = 1.0e-6
+
+     def interpolate_pr_auc(tp, fp, fn):
+       """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
+
+       Note here we derive & use a closed formula not present in the paper
+       - as follows:
+       Modeling all of TP (true positive weight),
+       FP (false positive weight) and their sum P = TP + FP (positive weight)
+       as varying linearly within each interval [A, B] between successive
+       thresholds, we get
+         Precision = (TP_A + slope * (P - P_A)) / P
+       with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
+       The area within the interval is thus (slope / total_pos_weight) times
+         int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+         int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+       where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
+         int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+       Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
+         slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
+       where dTP == TP_B - TP_A.
+       Note that when P_A == 0 the above calculation simplifies into
+         int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+       which is really equivalent to imputing constant precision throughout the
+       first bucket having >0 true positives.
+
+       Args:
+         tp: true positive counts
+         fp: false positive counts
+         fn: false negative counts
+       Returns:
+         pr_auc: an approximation of the area under the P-R curve.
+       """
+       dtp = tp[:num_thresholds - 1] - tp[1:]
+       p = tp + fp
+       prec_slope = math_ops.div_no_nan(
+           dtp,
+           math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
+           name='prec_slope')
+       intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
+       safe_p_ratio = array_ops.where(
+           math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
+           math_ops.div_no_nan(
+               p[:num_thresholds - 1],
+               math_ops.maximum(p[1:], 0),
+               name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
+       return math_ops.reduce_sum(
+           math_ops.div_no_nan(
+               prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
+               math_ops.maximum(tp[1:] + fn[1:], 0),
+               name='pr_auc_increment'),
+           name='interpolate_pr_auc')
+
+     def compute_auc(tp, fn, tn, fp, name):
+       """Computes the roc-auc or pr-auc based on confusion counts."""
+       if curve == 'PR':
+         if summation_method == 'trapezoidal':
+           logging.warning(
+               'Trapezoidal rule is known to produce incorrect PR-AUCs; '
+               'please switch to "careful_interpolation" instead.')
+         elif summation_method == 'careful_interpolation':
+           # This one is a bit tricky and is handled separately.
+           return interpolate_pr_auc(tp, fp, fn)
+       rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
+       if curve == 'ROC':
+         fp_rate = math_ops.div(fp, fp + tn + epsilon)
+         x = fp_rate
+         y = rec
+       else:  # curve == 'PR'.
+         prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
+         x = rec
+         y = prec
+       if summation_method in ('trapezoidal', 'careful_interpolation'):
+         # Note that the case ('PR', 'careful_interpolation') has been handled
+         # above.
+         return math_ops.reduce_sum(
+             math_ops.multiply(x[:num_thresholds - 1] - x[1:],
+                               (y[:num_thresholds - 1] + y[1:]) / 2.),
+             name=name)
+       elif summation_method == 'minoring':
+         return math_ops.reduce_sum(
+             math_ops.multiply(x[:num_thresholds - 1] - x[1:],
+                               math_ops.minimum(y[:num_thresholds - 1], y[1:])),
+             name=name)
+       elif summation_method == 'majoring':
+         return math_ops.reduce_sum(
+             math_ops.multiply(x[:num_thresholds - 1] - x[1:],
+                               math_ops.maximum(y[:num_thresholds - 1], y[1:])),
+             name=name)
+       else:
+         raise ValueError('Invalid summation_method: %s' % summation_method)
+
+     # sum up the areas of all the trapeziums
+     def compute_auc_value(_, values):
+       return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
+                          'value')
+
+     auc_value = _aggregate_across_replicas(metrics_collections,
+                                            compute_auc_value, values)
+     update_op = compute_auc(update_ops['tp'], update_ops['fn'],
+                             update_ops['tn'], update_ops['fp'], 'update_op')
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return auc_value, update_op
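
Usage sketch for the streaming AUC, assuming the public `tf.metrics.auc` mirrors this copy (it does, minus the `print`):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

labels = tf.placeholder(tf.bool, [None])
scores = tf.placeholder(tf.float32, [None])  # must already be in [0, 1]
auc_t, update_op = tf.metrics.auc(
    labels, scores, curve='PR', summation_method='careful_interpolation')

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  for batch_labels, batch_scores in [([True, False], [0.9, 0.4]),
                                     ([True], [0.7])]:
    sess.run(update_op, {labels: batch_labels, scores: batch_scores})
  print(sess.run(auc_t))  # read once, after all batches are accumulated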
+
+
+ @tf_export(v1=['metrics.mean_absolute_error'])
+ def mean_absolute_error(labels,
+                         predictions,
+                         weights=None,
+                         metrics_collections=None,
+                         updates_collections=None,
+                         name=None):
+   """Computes the mean absolute error between the labels and predictions.
+
+   The `mean_absolute_error` function creates two local variables,
+   `total` and `count` that are used to compute the mean absolute error. This
+   average is weighted by `weights`, and it is ultimately returned as
+   `mean_absolute_error`: an idempotent operation that simply divides `total` by
+   `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
+   absolute value of the differences between `predictions` and `labels`. Then
+   `update_op` increments `total` with the reduced sum of the product of
+   `weights` and `absolute_errors`, and it increments `count` with the reduced
+   sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of the same shape as `predictions`.
+     predictions: A `Tensor` of arbitrary shape.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_absolute_error` should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_absolute_error: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_absolute_error`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
+                        'when eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   absolute_errors = math_ops.abs(predictions - labels)
+   return mean(absolute_errors, weights, metrics_collections,
+               updates_collections, name or 'mean_absolute_error')
921
+
922
+
923
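+ # A minimal sketch of the (value, update_op) protocol shared by these metrics,
+ # assuming TensorFlow 1.x graph mode; illustrative only, not part of the
+ # packaged file.
+ def _sketch_mean_absolute_error():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([0.0, 1.0, 2.0])
+   predictions = tf.constant([0.5, 1.0, 1.0])
+   mae, update_op = tf.metrics.mean_absolute_error(labels, predictions)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)  # accumulates |0.5| + |0.0| + |1.0| over 3 values
+     return sess.run(mae)  # 0.5
+
+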
+ @tf_export(v1=['metrics.mean_cosine_distance'])
+ def mean_cosine_distance(labels,
+                          predictions,
+                          dim,
+                          weights=None,
+                          metrics_collections=None,
+                          updates_collections=None,
+                          name=None):
+   """Computes the cosine distance between the labels and predictions.
+
+   The `mean_cosine_distance` function creates two local variables,
+   `total` and `count` that are used to compute the average cosine distance
+   between `predictions` and `labels`. This average is weighted by `weights`,
+   and it is ultimately returned as `mean_distance`, which is an idempotent
+   operation that simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_distance`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of arbitrary shape.
+     predictions: A `Tensor` of the same shape as `labels`.
+     dim: The dimension along which the cosine distance is computed.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension). Also,
+       dimension `dim` must be `1`.
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_distance: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
+                        'eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   radial_diffs = math_ops.multiply(predictions, labels)
+   radial_diffs = math_ops.reduce_sum(
+       radial_diffs, axis=[
+           dim,
+       ], keepdims=True)
+   mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
+                                   'mean_cosine_distance')
+   mean_distance = math_ops.subtract(1.0, mean_distance)
+   update_op = math_ops.subtract(1.0, update_op)
+
+   if metrics_collections:
+     ops.add_to_collections(metrics_collections, mean_distance)
+
+   if updates_collections:
+     ops.add_to_collections(updates_collections, update_op)
+
+   return mean_distance, update_op
+
+
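+ # An illustrative sketch, assuming TF 1.x graph mode and L2-normalized inputs:
+ # the op effectively computes 1 - sum(predictions * labels) along `dim`, so it
+ # only matches cosine distance when both tensors are unit-normalized. Not part
+ # of the packaged file.
+ def _sketch_mean_cosine_distance():
+   import tensorflow.compat.v1 as tf
+   labels = tf.nn.l2_normalize(tf.constant([[1.0, 0.0], [0.0, 1.0]]), axis=1)
+   predictions = tf.nn.l2_normalize(
+       tf.constant([[1.0, 1.0], [0.0, 2.0]]), axis=1)
+   dist, update_op = tf.metrics.mean_cosine_distance(
+       labels, predictions, dim=1)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(dist)  # ((1 - 0.7071) + (1 - 1.0)) / 2 ~= 0.1464
+
+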
+ @tf_export(v1=['metrics.mean_per_class_accuracy'])
+ def mean_per_class_accuracy(labels,
+                             predictions,
+                             num_classes,
+                             weights=None,
+                             metrics_collections=None,
+                             updates_collections=None,
+                             name=None):
+   """Calculates the mean of the per-class accuracies.
+
+   Calculates the accuracy for each class, then takes the mean of that.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates the accuracy of each class and returns
+   them.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of ground truth labels with shape [batch size] and of
+       type `int32` or `int64`. The tensor will be flattened if its rank > 1.
+     predictions: A `Tensor` of prediction results for semantic labels, whose
+       shape is [batch size] and type `int32` or `int64`. The tensor will be
+       flattened if its rank > 1.
+     num_classes: The possible number of labels the prediction task can
+       have. This value must be provided, since two variables with shape =
+       [num_classes] will be allocated.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_per_class_accuracy` should be added to.
+     updates_collections: An optional list of collections `update_op` should be
+       added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_accuracy: A `Tensor` representing the mean per class accuracy.
+     update_op: An operation that updates the accuracy tensor.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
+                        'when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean_accuracy',
+                                      (predictions, labels, weights)):
+     labels = math_ops.cast(labels, dtypes.int64)
+
+     # Flatten the input if its rank > 1.
+     if labels.get_shape().ndims > 1:
+       labels = array_ops.reshape(labels, [-1])
+
+     if predictions.get_shape().ndims > 1:
+       predictions = array_ops.reshape(predictions, [-1])
+
+     # Check if shape is compatible.
+     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+     total = metric_variable([num_classes], dtypes.float32, name='total')
+     count = metric_variable([num_classes], dtypes.float32, name='count')
+
+     ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)
+
+     if labels.dtype != predictions.dtype:
+       predictions = math_ops.cast(predictions, labels.dtype)
+     is_correct = math_ops.cast(
+         math_ops.equal(predictions, labels), dtypes.float32)
+
+     if weights is not None:
+       if weights.get_shape().ndims > 1:
+         weights = array_ops.reshape(weights, [-1])
+       weights = math_ops.cast(weights, dtypes.float32)
+
+       is_correct *= weights
+       ones *= weights
+
+     update_total_op = state_ops.scatter_add(total, labels, ones)
+     update_count_op = state_ops.scatter_add(count, labels, is_correct)
+
+     def compute_mean_accuracy(_, count, total):
+       per_class_accuracy = math_ops.div_no_nan(
+           count, math_ops.maximum(total, 0), name=None)
+       mean_accuracy_v = math_ops.reduce_mean(
+           per_class_accuracy, name='mean_accuracy')
+       return mean_accuracy_v
+
+     mean_accuracy_v = _aggregate_across_replicas(metrics_collections,
+                                                  compute_mean_accuracy, count,
+                                                  total)
+
+     update_op = math_ops.div_no_nan(
+         update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_accuracy_v, update_op
+
+
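+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file. Note that div_no_nan maps classes with no labels to
+ # accuracy 0, and those zeros still enter the final reduce_mean.
+ def _sketch_mean_per_class_accuracy():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([0, 0, 1, 1])
+   predictions = tf.constant([0, 1, 1, 1])
+   acc, update_op = tf.metrics.mean_per_class_accuracy(
+       labels, predictions, num_classes=2)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(acc)  # class 0: 1/2, class 1: 2/2 -> mean 0.75
+
+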
+ @tf_export(v1=['metrics.mean_iou'])
+ def mean_iou(labels,
+              predictions,
+              num_classes,
+              weights=None,
+              metrics_collections=None,
+              updates_collections=None,
+              name=None):
+   """Calculate per-step mean Intersection-Over-Union (mIOU).
+
+   Mean Intersection-Over-Union is a common evaluation metric for
+   semantic image segmentation, which first computes the IOU for each
+   semantic class and then computes the average over classes.
+   IOU is defined as follows:
+     IOU = true_positive / (true_positive + false_positive + false_negative).
+   The predictions are accumulated in a confusion matrix, weighted by `weights`,
+   and mIOU is then calculated from it.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `mean_iou`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of ground truth labels with shape [batch size] and of
+       type `int32` or `int64`. The tensor will be flattened if its rank > 1.
+     predictions: A `Tensor` of prediction results for semantic labels, whose
+       shape is [batch size] and type `int32` or `int64`. The tensor will be
+       flattened if its rank > 1.
+     num_classes: The possible number of labels the prediction task can
+       have. This value must be provided, since a confusion matrix of
+       dimension = [num_classes, num_classes] will be allocated.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `mean_iou`
+       should be added to.
+     updates_collections: An optional list of collections `update_op` should be
+       added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_iou: A `Tensor` representing the mean intersection-over-union.
+     update_op: An operation that increments the confusion matrix.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_iou is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean_iou',
+                                      (predictions, labels, weights)):
+     # Check if shape is compatible.
+     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+     total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
+                                                       num_classes, weights)
+
+     def compute_mean_iou(_, total_cm):
+       """Compute the mean intersection-over-union via the confusion matrix."""
+       sum_over_row = math_ops.cast(
+           math_ops.reduce_sum(total_cm, 0), dtypes.float32)
+       sum_over_col = math_ops.cast(
+           math_ops.reduce_sum(total_cm, 1), dtypes.float32)
+       cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32)
+       denominator = sum_over_row + sum_over_col - cm_diag
+
+       # The mean is only computed over classes that appear in the
+       # label or prediction tensor. If the denominator is 0, we need to
+       # ignore the class.
+       num_valid_entries = math_ops.reduce_sum(
+           math_ops.cast(
+               math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
+
+       # If the value of the denominator is 0, set it to 1 to avoid
+       # zero division.
+       denominator = array_ops.where(
+           math_ops.greater(denominator, 0), denominator,
+           array_ops.ones_like(denominator))
+       iou = math_ops.div(cm_diag, denominator)
+
+       # If the number of valid entries is 0 (no classes) we return 0.
+       result = array_ops.where(
+           math_ops.greater(num_valid_entries, 0),
+           math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
+       return result
+
+     # TODO(priyag): Use outside_compilation if in TPU context.
+     mean_iou_v = _aggregate_across_replicas(metrics_collections,
+                                             compute_mean_iou, total_cm)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_iou_v, update_op
+
+
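+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file. update_op grows a [num_classes, num_classes] confusion
+ # matrix; the value tensor turns it into per-class IOUs and averages over
+ # classes that actually appear.
+ def _sketch_mean_iou():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([0, 0, 1, 1])
+   predictions = tf.constant([0, 1, 1, 1])
+   miou, update_op = tf.metrics.mean_iou(labels, predictions, num_classes=2)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     # IOU(class 0) = 1 / (1 + 0 + 1); IOU(class 1) = 2 / (2 + 1 + 0)
+     return sess.run(miou)  # (0.5 + 0.6667) / 2 ~= 0.5833
+
+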
+ @tf_export(v1=['metrics.mean_relative_error'])
+ def mean_relative_error(labels,
+                         predictions,
+                         normalizer,
+                         weights=None,
+                         metrics_collections=None,
+                         updates_collections=None,
+                         name=None):
+   """Computes the mean relative error by normalizing with the given values.
+
+   The `mean_relative_error` function creates two local variables,
+   `total` and `count` that are used to compute the mean relative absolute error.
+   This average is weighted by `weights`, and it is ultimately returned as
+   `mean_relative_error`: an idempotent operation that simply divides `total` by
+   `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_relative_error`. Internally, a `relative_errors` operation divides the
+   absolute value of the differences between `predictions` and `labels` by the
+   `normalizer`. Then `update_op` increments `total` with the reduced sum of the
+   product of `weights` and `relative_errors`, and it increments `count` with the
+   reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of the same shape as `predictions`.
+     predictions: A `Tensor` of arbitrary shape.
+     normalizer: A `Tensor` of the same shape as `predictions`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_relative_error` should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_relative_error: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_relative_error`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
+                        'eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+
+   predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
+       predictions, normalizer)
+   predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
+   relative_errors = array_ops.where(
+       math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
+       math_ops.div(math_ops.abs(labels - predictions), normalizer))
+   return mean(relative_errors, weights, metrics_collections,
+               updates_collections, name or 'mean_relative_error')
+
+
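+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file. Using the labels themselves as the normalizer yields
+ # |pred - label| / label.
+ def _sketch_mean_relative_error():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([2.0, 4.0])
+   predictions = tf.constant([1.0, 5.0])
+   mre, update_op = tf.metrics.mean_relative_error(
+       labels, predictions, normalizer=labels)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(mre)  # (1/2 + 1/4) / 2 = 0.375
+
+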
+ @tf_export(v1=['metrics.mean_squared_error'])
+ def mean_squared_error(labels,
+                        predictions,
+                        weights=None,
+                        metrics_collections=None,
+                        updates_collections=None,
+                        name=None):
+   """Computes the mean squared error between the labels and predictions.
+
+   The `mean_squared_error` function creates two local variables,
+   `total` and `count` that are used to compute the mean squared error.
+   This average is weighted by `weights`, and it is ultimately returned as
+   `mean_squared_error`: an idempotent operation that simply divides `total` by
+   `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_squared_error`. Internally, a `squared_error` operation computes the
+   element-wise square of the difference between `predictions` and `labels`. Then
+   `update_op` increments `total` with the reduced sum of the product of
+   `weights` and `squared_error`, and it increments `count` with the reduced sum
+   of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of the same shape as `predictions`.
+     predictions: A `Tensor` of arbitrary shape.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_squared_error` should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_squared_error: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_squared_error`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
+                        'eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   squared_error = math_ops.squared_difference(labels, predictions)
+   return mean(squared_error, weights, metrics_collections, updates_collections,
+               name or 'mean_squared_error')
+
+
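+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file.
+ def _sketch_mean_squared_error():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([0.0, 1.0, 2.0])
+   predictions = tf.constant([0.0, 2.0, 4.0])
+   mse, update_op = tf.metrics.mean_squared_error(labels, predictions)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(mse)  # (0 + 1 + 4) / 3
+
+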
+ @tf_export(v1=['metrics.mean_tensor'])
+ def mean_tensor(values,
+                 weights=None,
+                 metrics_collections=None,
+                 updates_collections=None,
+                 name=None):
+   """Computes the element-wise (weighted) mean of the given tensors.
+
+   In contrast to the `mean` function which returns a scalar with the
+   mean, this function returns an average tensor with the same shape as the
+   input tensors.
+
+   The `mean_tensor` function creates two local variables,
+   `total_tensor` and `count_tensor` that are used to compute the average of
+   `values`. This average is ultimately returned as `mean` which is an idempotent
+   operation that simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `mean`.
+   `update_op` increments `total` with the reduced sum of the product of `values`
+   and `weights`, and it increments `count` with the reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A `Tensor` of arbitrary dimensions.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that `mean`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op`
+       should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean: A float `Tensor` representing the current mean, the value of `total`
+       divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_value`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_tensor is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean', (values, weights)):
+     values = math_ops.cast(values, dtypes.float32)
+     total = metric_variable(
+         values.get_shape(), dtypes.float32, name='total_tensor')
+     count = metric_variable(
+         values.get_shape(), dtypes.float32, name='count_tensor')
+
+     num_values = array_ops.ones_like(values)
+     if weights is not None:
+       values, _, weights = _remove_squeezable_dimensions(
+           predictions=values, labels=None, weights=weights)
+       weights = weights_broadcast_ops.broadcast_weights(
+           math_ops.cast(weights, dtypes.float32), values)
+       values = math_ops.multiply(values, weights)
+       num_values = math_ops.multiply(num_values, weights)
+
+     update_total_op = state_ops.assign_add(total, values, use_locking=True)
+     with ops.control_dependencies([values]):
+       update_count_op = state_ops.assign_add(
+           count, num_values, use_locking=True)
+
+     compute_mean = lambda _, t, c: math_ops.div_no_nan(  # noqa: E731
+         t, math_ops.maximum(c, 0), name='value')
+
+     mean_t = _aggregate_across_replicas(metrics_collections, compute_mean,
+                                         total, count)
+
+     update_op = math_ops.div_no_nan(
+         update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_t, update_op
+
+
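+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file. Unlike `mean`, the running average keeps the full
+ # shape of the input tensor.
+ def _sketch_mean_tensor():
+   import tensorflow.compat.v1 as tf
+   values = tf.placeholder(tf.float32, shape=[2])
+   avg, update_op = tf.metrics.mean_tensor(values)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op, feed_dict={values: [0.0, 2.0]})
+     sess.run(update_op, feed_dict={values: [2.0, 4.0]})
+     return sess.run(avg)  # element-wise mean over both batches: [1.0, 3.0]
+
+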
+ @tf_export(v1=['metrics.percentage_below'])
+ def percentage_below(values,
+                      threshold,
+                      weights=None,
+                      metrics_collections=None,
+                      updates_collections=None,
+                      name=None):
+   """Computes the percentage of values less than the given threshold.
+
+   The `percentage_below` function creates two local variables,
+   `total` and `count` that are used to compute the percentage of `values` that
+   fall below `threshold`. This rate is weighted by `weights`, and it is
+   ultimately returned as `percentage` which is an idempotent operation that
+   simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `percentage`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A numeric `Tensor` of arbitrary size.
+     threshold: A scalar threshold.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     percentage: A `Tensor` representing the current mean, the value of `total`
+       divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.percentage_below is not supported when '
+                        'eager execution is enabled.')
+
+   is_below_threshold = math_ops.cast(
+       math_ops.less(values, threshold), dtypes.float32)
+   return mean(is_below_threshold, weights, metrics_collections,
+               updates_collections, name or 'percentage_below_threshold')
+
+
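+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file: the weighted fraction of values strictly less than
+ # the threshold.
+ def _sketch_percentage_below():
+   import tensorflow.compat.v1 as tf
+   values = tf.constant([0.1, 0.5, 0.9, 1.5])
+   pct, update_op = tf.metrics.percentage_below(values, threshold=1.0)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(pct)  # 3 of 4 values are < 1.0 -> 0.75
+
+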
+ def _count_condition(values,
+                      weights=None,
+                      metrics_collections=None,
+                      updates_collections=None):
+   """Sums the weights of cases where the given values are True.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A `bool` `Tensor` of arbitrary size.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+   """
+   check_ops.assert_type(values, dtypes.bool)
+   count = metric_variable([], dtypes.float32, name='count')
+
+   values = math_ops.cast(values, dtypes.float32)
+   if weights is not None:
+     with ops.control_dependencies(
+         (check_ops.assert_rank_in(weights, (0, array_ops.rank(values))),)):
+       weights = math_ops.cast(weights, dtypes.float32)
+       values = math_ops.multiply(values, weights)
+
+   value_tensor = _aggregate_variable(count, metrics_collections)
+
+   update_op = state_ops.assign_add(
+       count, math_ops.reduce_sum(values), use_locking=True)
+   if updates_collections:
+     ops.add_to_collections(updates_collections, update_op)
+
+   return value_tensor, update_op
+
+
+ @tf_export(v1=['metrics.false_negatives'])
+ def false_negatives(labels,
+                     predictions,
+                     weights=None,
+                     metrics_collections=None,
+                     updates_collections=None,
+                     name=None):
+   """Computes the total number of false negatives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_negatives is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_negatives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_false_negative = math_ops.logical_and(
+         math_ops.equal(labels, True), math_ops.equal(predictions, False))
+     return _count_condition(is_false_negative, weights, metrics_collections,
+                             updates_collections)
+
+
+ @tf_export(v1=['metrics.false_negatives_at_thresholds'])
+ def false_negatives_at_thresholds(labels,
+                                   predictions,
+                                   thresholds,
+                                   weights=None,
+                                   metrics_collections=None,
+                                   updates_collections=None,
+                                   name=None):
+   """Computes false negatives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `false_negatives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     false_negatives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `false_negatives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_negatives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('fn',))
+
+     fn_value = _aggregate_variable(values['fn'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['fn'])
+
+     return fn_value, update_ops['fn']
+
+
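+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file. The plain counter casts predictions to bool; the
+ # *_at_thresholds variant instead counts at every cutoff in a list.
+ def _sketch_false_negatives_at_thresholds():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([1.0, 1.0, 0.0, 1.0])
+   scores = tf.constant([0.9, 0.2, 0.1, 0.6])
+   fn_at, update_op = tf.metrics.false_negatives_at_thresholds(
+       labels, scores, thresholds=[0.5])
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(fn_at)  # only the positive scored 0.2 is missed: [1.0]
+
+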
+ @tf_export(v1=['metrics.false_positives'])
+ def false_positives(labels,
+                     predictions,
+                     weights=None,
+                     metrics_collections=None,
+                     updates_collections=None,
+                     name=None):
+   """Sum the weights of false positives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_positives is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_positives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_false_positive = math_ops.logical_and(
+         math_ops.equal(labels, False), math_ops.equal(predictions, True))
+     return _count_condition(is_false_positive, weights, metrics_collections,
+                             updates_collections)
+
+
+ @tf_export(v1=['metrics.false_positives_at_thresholds'])
+ def false_positives_at_thresholds(labels,
+                                   predictions,
+                                   thresholds,
+                                   weights=None,
+                                   metrics_collections=None,
+                                   updates_collections=None,
+                                   name=None):
+   """Computes false positives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `false_positives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     false_positives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `false_positives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_positives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('fp',))
+
+     fp_value = _aggregate_variable(values['fp'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['fp'])
+
+     return fp_value, update_ops['fp']
+
+
+ @tf_export(v1=['metrics.true_negatives'])
+ def true_negatives(labels,
+                    predictions,
+                    weights=None,
+                    metrics_collections=None,
+                    updates_collections=None,
+                    name=None):
+   """Sum the weights of true_negatives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_negatives is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_negatives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_true_negative = math_ops.logical_and(
+         math_ops.equal(labels, False), math_ops.equal(predictions, False))
+     return _count_condition(is_true_negative, weights, metrics_collections,
+                             updates_collections)
+
+
+ @tf_export(v1=['metrics.true_negatives_at_thresholds'])
+ def true_negatives_at_thresholds(labels,
+                                  predictions,
+                                  thresholds,
+                                  weights=None,
+                                  metrics_collections=None,
+                                  updates_collections=None,
+                                  name=None):
+   """Computes true negatives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `true_negatives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     true_negatives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `true_negatives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_negatives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('tn',))
+
+     tn_value = _aggregate_variable(values['tn'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['tn'])
+
+     return tn_value, update_ops['tn']
+
+
+ @tf_export(v1=['metrics.true_positives'])
+ def true_positives(labels,
+                    predictions,
+                    weights=None,
+                    metrics_collections=None,
+                    updates_collections=None,
+                    name=None):
+   """Sum the weights of true_positives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_positives is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_positives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_true_positive = math_ops.logical_and(
+         math_ops.equal(labels, True), math_ops.equal(predictions, True))
+     return _count_condition(is_true_positive, weights, metrics_collections,
+                             updates_collections)
+
+
+ @tf_export(v1=['metrics.true_positives_at_thresholds'])
+ def true_positives_at_thresholds(labels,
+                                  predictions,
+                                  thresholds,
+                                  weights=None,
+                                  metrics_collections=None,
+                                  updates_collections=None,
+                                  name=None):
+   """Computes true positives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `true_positives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     true_positives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `true_positives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_positives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('tp',))
+
+     tp_value = _aggregate_variable(values['tp'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['tp'])
+
+     return tp_value, update_ops['tp']
+
+
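+ # An illustrative sketch, assuming TF 1.x graph mode; not part of the packaged
+ # file. The four threshold counters are views of one confusion computation, so
+ # tp + fn equals the number of positives at every threshold.
+ def _sketch_confusion_counts():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([1.0, 1.0, 0.0, 0.0])
+   scores = tf.constant([0.9, 0.3, 0.8, 0.1])
+   thresholds = [0.2, 0.5]
+   tp, tp_up = tf.metrics.true_positives_at_thresholds(
+       labels, scores, thresholds)
+   fn, fn_up = tf.metrics.false_negatives_at_thresholds(
+       labels, scores, thresholds)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run([tp_up, fn_up])
+     return sess.run([tp, fn])  # tp=[2, 1], fn=[0, 1]; tp + fn == 2 positives
+
+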
+ @tf_export(v1=['metrics.precision'])
+ def precision(labels,
+               predictions,
+               weights=None,
+               metrics_collections=None,
+               updates_collections=None,
+               name=None):
+   """Computes the precision of the predictions with respect to the labels.
+
+   The `precision` function creates two local variables,
+   `true_positives` and `false_positives`, that are used to compute the
+   precision. This value is ultimately returned as `precision`, an idempotent
+   operation that simply divides `true_positives` by the sum of `true_positives`
+   and `false_positives`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision`. `update_op` weights each prediction by the corresponding value in
+   `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `precision` should
+       be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     precision: Scalar float `Tensor` with the value of `true_positives`
+       divided by the sum of `true_positives` and `false_positives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_positives` variables appropriately and whose value matches
+       `precision`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.precision is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'precision',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+
+     true_p, true_positives_update_op = true_positives(
+         labels,
+         predictions,
+         weights,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None)
+     false_p, false_positives_update_op = false_positives(
+         labels,
+         predictions,
+         weights,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None)
+
+     def compute_precision(tp, fp, name):
+       return array_ops.where(
+           math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
+
+     def once_across_replicas(_, true_p, false_p):
+       return compute_precision(true_p, false_p, 'value')
+
+     p = _aggregate_across_replicas(metrics_collections, once_across_replicas,
+                                    true_p, false_p)
+
+     update_op = compute_precision(true_positives_update_op,
+                                   false_positives_update_op, 'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return p, update_op
+
+
+ @tf_export(v1=['metrics.precision_at_thresholds'])
+ def precision_at_thresholds(labels,
+                             predictions,
+                             thresholds,
+                             weights=None,
+                             metrics_collections=None,
+                             updates_collections=None,
+                             name=None):
+   """Computes precision values for different `thresholds` on `predictions`.
+
+   The `precision_at_thresholds` function creates four local variables,
+   `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
+   for various values of thresholds. `precision[i]` is defined as the total
+   weight of values in `predictions` above `thresholds[i]` whose corresponding
+   entry in `labels` is `True`, divided by the total weight of values in
+   `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
+   false_positives[i])`).
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `precision` should
+       be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     precision: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables that
+       are used in the computation of `precision`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.precision_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'precision_at_thresholds',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights, includes=('tp', 'fp'))
+
+     # Avoid division by zero.
+     epsilon = 1e-7
+
+     def compute_precision(tp, fp, name):
+       return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
+
+     def precision_across_replicas(_, values):
+       return compute_precision(values['tp'], values['fp'], 'value')
+
+     prec = _aggregate_across_replicas(metrics_collections,
+                                       precision_across_replicas, values)
+
+     update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+                                   'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return prec, update_op
+
+
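+ # A minimal sketch, assuming TF 1.x graph mode; illustrative only, not part
+ # of the packaged file. The epsilon in the denominator means thresholds with
+ # no positive predictions yield ~0 precision instead of a division error.
+ def _sketch_precision_at_thresholds():
+   import tensorflow.compat.v1 as tf
+   labels = tf.constant([1.0, 0.0, 1.0, 0.0])
+   scores = tf.constant([0.9, 0.6, 0.4, 0.1])
+   prec, update_op = tf.metrics.precision_at_thresholds(
+       labels, scores, thresholds=[0.3, 0.7])
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(prec)  # ~[0.667, 1.0]: tp / (tp + fp) per threshold
+
+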
+ @tf_export(v1=['metrics.recall'])
2136
+ def recall(labels,
2137
+ predictions,
2138
+ weights=None,
2139
+ metrics_collections=None,
2140
+ updates_collections=None,
2141
+ name=None):
2142
+ """Computes the recall of the predictions with respect to the labels.
2143
+
2144
+ The `recall` function creates two local variables, `true_positives`
2145
+ and `false_negatives`, that are used to compute the recall. This value is
2146
+ ultimately returned as `recall`, an idempotent operation that simply divides
2147
+ `true_positives` by the sum of `true_positives` and `false_negatives`.
2148
+
2149
+ For estimation of the metric over a stream of data, the function creates an
2150
+ `update_op` that updates these variables and returns the `recall`. `update_op`
2151
+ weights each prediction by the corresponding value in `weights`.
2152
+
2153
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2154
+
2155
+ Args:
2156
+ labels: The ground truth values, a `Tensor` whose dimensions must match
2157
+ `predictions`. Will be cast to `bool`.
2158
+ predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2159
+ be cast to `bool`.
2160
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
2161
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2162
+ be either `1`, or the same as the corresponding `labels` dimension).
2163
+ metrics_collections: An optional list of collections that `recall` should
2164
+ be added to.
2165
+ updates_collections: An optional list of collections that `update_op` should
2166
+ be added to.
2167
+ name: An optional variable_scope name.
2168
+
2169
+ Returns:
2170
+ recall: Scalar float `Tensor` with the value of `true_positives` divided
2171
+ by the sum of `true_positives` and `false_negatives`.
2172
+ update_op: `Operation` that increments `true_positives` and
2173
+ `false_negatives` variables appropriately and whose value matches
2174
+ `recall`.
2175
+
2176
+ Raises:
2177
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
2178
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
2179
+ either `metrics_collections` or `updates_collections` are not a list or
2180
+ tuple.
2181
+ RuntimeError: If eager execution is enabled.
2182
+ """
2183
+ if context.executing_eagerly():
2184
+ raise RuntimeError('tf.metrics.recall is not '
2185
+ 'supported when eager execution is enabled.')
2186
+
2187
+ with variable_scope.variable_scope(name, 'recall',
2188
+ (predictions, labels, weights)):
2189
+ predictions, labels, weights = _remove_squeezable_dimensions(
2190
+ predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2191
+ labels=math_ops.cast(labels, dtype=dtypes.bool),
2192
+ weights=weights)
2193
+
2194
+ true_p, true_positives_update_op = true_positives(
2195
+ labels,
2196
+ predictions,
2197
+ weights,
2198
+ metrics_collections=None,
2199
+ updates_collections=None,
2200
+ name=None)
2201
+ false_n, false_negatives_update_op = false_negatives(
2202
+ labels,
2203
+ predictions,
2204
+ weights,
2205
+ metrics_collections=None,
2206
+ updates_collections=None,
2207
+ name=None)
2208
+
2209
+ def compute_recall(true_p, false_n, name):
2210
+ return array_ops.where(
2211
+ math_ops.greater(true_p + false_n, 0),
2212
+ math_ops.div(true_p, true_p + false_n), 0, name)
2213
+
2214
+ def once_across_replicas(_, true_p, false_n):
2215
+ return compute_recall(true_p, false_n, 'value')
2216
+
2217
+ rec = _aggregate_across_replicas(metrics_collections, once_across_replicas,
2218
+ true_p, false_n)
2219
+
2220
+ update_op = compute_recall(true_positives_update_op,
2221
+ false_negatives_update_op, 'update_op')
2222
+ if updates_collections:
2223
+ ops.add_to_collections(updates_collections, update_op)
2224
+
2225
+ return rec, update_op
2226
+
2227
+
2228
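A similarly minimal sketch for the streaming recall above, again via the public `tf.compat.v1.metrics` equivalent (values are illustrative):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([True, True, False, True])
predictions = tf.constant([True, False, True, True])

rec, update_op = tf.metrics.recall(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)   # tp = 2 (items 0 and 3), fn = 1 (item 1)
  print(sess.run(rec))  # 2 / (2 + 1) = 0.666...
```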
+ def _at_k_name(name, k=None, class_id=None):
2229
+ if k is not None:
2230
+ name = '%s_at_%d' % (name, k)
2231
+ else:
2232
+ name = '%s_at_k' % (name)
2233
+ if class_id is not None:
2234
+ name = '%s_class%d' % (name, class_id)
2235
+ return name
2236
+
2237
+
2238
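The `_at_k_name` helper above is pure string formatting; given its definition, the default op names it produces look like this (in-module usage, shown only for illustration):

```python
# Default op names built by _at_k_name:
assert _at_k_name('recall', k=5) == 'recall_at_5'
assert _at_k_name('recall', k=5, class_id=3) == 'recall_at_5_class3'
assert _at_k_name('recall') == 'recall_at_k'
```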
+ def _select_class_id(ids, selected_id):
2239
+ """Filter all but `selected_id` out of `ids`.
2240
+
2241
+ Args:
2242
+ ids: `int64` `Tensor` or `SparseTensor` of IDs.
2243
+ selected_id: Int id to select.
2244
+
2245
+ Returns:
2246
+ `SparseTensor` of same dimensions as `ids`. This contains only the entries
2247
+ equal to `selected_id`.
2248
+ """
2249
+ ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
2250
+ if isinstance(ids, sparse_tensor.SparseTensor):
2251
+ return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
2252
+ selected_id))
2253
+
2254
+ # TODO(ptucker): Make this more efficient, maybe add a sparse version of
2255
+ # tf.equal and tf.reduce_any?
2256
+
2257
+ # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
2258
+ ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
2259
+ ids_last_dim = array_ops.size(ids_shape) - 1
2260
+ filled_selected_id_shape = math_ops.reduced_shape(
2261
+ ids_shape, array_ops.reshape(ids_last_dim, [1]))
2262
+
2263
+ # Intersect `ids` with the selected ID.
2264
+ filled_selected_id = array_ops.fill(filled_selected_id_shape,
2265
+ math_ops.cast(selected_id, dtypes.int64))
2266
+ result = sets.set_intersection(filled_selected_id, ids)
2267
+ return sparse_tensor.SparseTensor(
2268
+ indices=result.indices, values=result.values, dense_shape=ids_shape)
2269
+
2270
+
2271
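The dense branch of `_select_class_id` boils down to a per-row set intersection with a singleton set containing the selected id; a minimal sketch of that mechanic using public set ops (illustrative values, not this module's own API):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ids = tf.constant([[1, 2, 3], [3, 4, 5]], dtype=tf.int64)
# Intersect each row with the singleton set {3}, as the helper does.
filled = tf.fill([2, 1], tf.constant(3, dtype=tf.int64))
selected = tf.sets.set_intersection(filled, ids)

with tf.Session() as sess:
  print(sess.run(selected).values)  # [3 3] -- only the selected id survives
```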
+ def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
2272
+ """If class ID is specified, filter all other classes.
2273
+
2274
+ Args:
2275
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2276
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2277
+ target classes for the associated prediction. Commonly, N=1 and `labels`
2278
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
2279
+ `predictions_idx`.
2280
+ predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
2281
+ where N >= 1. Commonly, N=1 and `predictions_idx` has shape
2282
+ [batch size, k].
2283
+ selected_id: Int id to select.
2284
+
2285
+ Returns:
2286
+ Tuple of `labels` and `predictions_idx`, possibly with classes removed.
2287
+ """
2288
+ if selected_id is None:
2289
+ return labels, predictions_idx
2290
+ return (_select_class_id(labels, selected_id),
2291
+ _select_class_id(predictions_idx, selected_id))
2292
+
2293
+
2294
+ def _sparse_true_positive_at_k(labels,
2295
+ predictions_idx,
2296
+ class_id=None,
2297
+ weights=None,
2298
+ name=None):
2299
+ """Calculates true positives for recall@k and precision@k.
2300
+
2301
+ If `class_id` is specified, calculate binary true positives for `class_id`
2302
+ only.
2303
+ If `class_id` is not specified, calculate metrics for `k` predicted vs
2304
+ `n` label classes, where `n` is the 2nd dimension of `labels`.
2305
+
2306
+ Args:
2307
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2308
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2309
+ target classes for the associated prediction. Commonly, N=1 and `labels`
2310
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
2311
+ `predictions_idx`.
2312
+ predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2313
+ top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2314
+ match `labels`.
2315
+ class_id: Class for which we want binary metrics.
2316
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2317
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2318
+ dimensions must be either `1`, or the same as the corresponding `labels`
2319
+ dimension).
2320
+ name: Name of operation.
2321
+
2322
+ Returns:
2323
+ A [D1, ... DN] `Tensor` of true positive counts.
2324
+ """
2325
+ with ops.name_scope(name, 'true_positives',
2326
+ (predictions_idx, labels, weights)):
2327
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2328
+ class_id)
2329
+ tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
2330
+ tp = math_ops.cast(tp, dtypes.float64)
2331
+ if weights is not None:
2332
+ with ops.control_dependencies(
2333
+ (weights_broadcast_ops.assert_broadcastable(weights, tp),)):
2334
+ weights = math_ops.cast(weights, dtypes.float64)
2335
+ tp = math_ops.multiply(tp, weights)
2336
+ return tp
2337
+
2338
+
2339
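The true-positive count above is just the per-row size of the intersection between the predicted and true class sets; a short sketch of the same mechanic with public set ops (illustrative values):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

predictions_idx = tf.constant([[0, 1], [1, 2]], dtype=tf.int64)  # top-2 ids
labels = tf.constant([[0, 2], [2, 7]], dtype=tf.int64)           # true ids

# Per-row size of set_intersection(predictions, labels), as in the helper.
tp = tf.sets.set_size(tf.sets.set_intersection(predictions_idx, labels))

with tf.Session() as sess:
  print(sess.run(tp))  # [1 1] -- one overlapping class id per row
```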
+ def _streaming_sparse_true_positive_at_k(labels,
2340
+ predictions_idx,
2341
+ k=None,
2342
+ class_id=None,
2343
+ weights=None,
2344
+ name=None):
2345
+ """Calculates weighted per step true positives for recall@k and precision@k.
2346
+
2347
+ If `class_id` is specified, calculate binary true positives for `class_id`
2348
+ only.
2349
+ If `class_id` is not specified, calculate metrics for `k` predicted vs
2350
+ `n` label classes, where `n` is the 2nd dimension of `labels`.
2351
+
2352
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2353
+
2354
+ Args:
2355
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2356
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2357
+ target classes for the associated prediction. Commonly, N=1 and `labels`
2358
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
2359
+ `predictions_idx`.
2360
+ predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2361
+ top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2362
+ match `labels`.
2363
+ k: Integer, k for @k metric. This is only used for default op name.
2364
+ class_id: Class for which we want binary metrics.
2365
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2366
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2367
+ dimensions must be either `1`, or the same as the corresponding `labels`
2368
+ dimension).
2369
+ name: Name of new variable, and namespace for other dependent ops.
2370
+
2371
+ Returns:
2372
+ A tuple of `Variable` and update `Operation`.
2373
+
2374
+ Raises:
2375
+ ValueError: If `weights` is not `None` and has an incompatible shape.
2376
+ """
2377
+ with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
2378
+ (predictions_idx, labels, weights)) as scope:
2379
+ tp = _sparse_true_positive_at_k(
2380
+ predictions_idx=predictions_idx,
2381
+ labels=labels,
2382
+ class_id=class_id,
2383
+ weights=weights)
2384
+ batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64)
2385
+
2386
+ var = metric_variable([], dtypes.float64, name=scope)
2387
+ return var, state_ops.assign_add(
2388
+ var, batch_total_tp, name='update', use_locking=True)
2389
+
2390
+
2391
+ def _sparse_false_negative_at_k(labels,
2392
+ predictions_idx,
2393
+ class_id=None,
2394
+ weights=None):
2395
+ """Calculates false negatives for recall@k.
2396
+
2397
+ If `class_id` is specified, calculate binary false negatives for `class_id`
2398
+ only.
2399
+ If `class_id` is not specified, calculate metrics for `k` predicted vs
2400
+ `n` label classes, where `n` is the 2nd dimension of `labels`.
2401
+
2402
+ Args:
2403
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2404
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2405
+ target classes for the associated prediction. Commonly, N=1 and `labels`
2406
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
2407
+ `predictions_idx`.
2408
+ predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2409
+ top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2410
+ match `labels`.
2411
+ class_id: Class for which we want binary metrics.
2412
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2413
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2414
+ dimensions must be either `1`, or the same as the corresponding `labels`
2415
+ dimension).
2416
+
2417
+ Returns:
2418
+ A [D1, ... DN] `Tensor` of false negative counts.
2419
+ """
2420
+ with ops.name_scope(None, 'false_negatives',
2421
+ (predictions_idx, labels, weights)):
2422
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2423
+ class_id)
2424
+ fn = sets.set_size(
2425
+ sets.set_difference(predictions_idx, labels, aminusb=False))
2426
+ fn = math_ops.cast(fn, dtypes.float64)
2427
+ if weights is not None:
2428
+ with ops.control_dependencies(
2429
+ (weights_broadcast_ops.assert_broadcastable(weights, fn),)):
2430
+ weights = math_ops.cast(weights, dtypes.float64)
2431
+ fn = math_ops.multiply(fn, weights)
2432
+ return fn
2433
+
2434
+
2435
+ def _streaming_sparse_false_negative_at_k(labels,
2436
+ predictions_idx,
2437
+ k,
2438
+ class_id=None,
2439
+ weights=None,
2440
+ name=None):
2441
+ """Calculates weighted per step false negatives for recall@k.
2442
+
2443
+ If `class_id` is specified, calculate binary false negatives for `class_id`
2444
+ only.
2445
+ If `class_id` is not specified, calculate metrics for `k` predicted vs
2446
+ `n` label classes, where `n` is the 2nd dimension of `labels`.
2447
+
2448
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2449
+
2450
+ Args:
2451
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2452
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2453
+ target classes for the associated prediction. Commonly, N=1 and `labels`
2454
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
2455
+ `predictions_idx`.
2456
+ predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2457
+ top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2458
+ match `labels`.
2459
+ k: Integer, k for @k metric. This is only used for default op name.
2460
+ class_id: Class for which we want binary metrics.
2461
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2462
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2463
+ dimensions must be either `1`, or the same as the corresponding `labels`
2464
+ dimension).
2465
+ name: Name of new variable, and namespace for other dependent ops.
2466
+
2467
+ Returns:
2468
+ A tuple of `Variable` and update `Operation`.
2469
+
2470
+ Raises:
2471
+ ValueError: If `weights` is not `None` and has an incompatible shape.
2472
+ """
2473
+ with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
2474
+ (predictions_idx, labels, weights)) as scope:
2475
+ fn = _sparse_false_negative_at_k(
2476
+ predictions_idx=predictions_idx,
2477
+ labels=labels,
2478
+ class_id=class_id,
2479
+ weights=weights)
2480
+ batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64)
2481
+
2482
+ var = metric_variable([], dtypes.float64, name=scope)
2483
+ return var, state_ops.assign_add(
2484
+ var, batch_total_fn, name='update', use_locking=True)
2485
+
2486
+
2487
+ @tf_export(v1=['metrics.recall_at_k'])
2488
+ def recall_at_k(labels,
2489
+ predictions,
2490
+ k,
2491
+ class_id=None,
2492
+ weights=None,
2493
+ metrics_collections=None,
2494
+ updates_collections=None,
2495
+ name=None):
2496
+ """Computes recall@k of the predictions with respect to sparse labels.
2497
+
2498
+ If `class_id` is specified, we calculate recall by considering only the
2499
+ entries in the batch for which `class_id` is in the label, and computing
2500
+ the fraction of them for which `class_id` is in the top-k `predictions`.
2501
+ If `class_id` is not specified, we'll calculate recall as how often on
2502
+ average a class among the labels of a batch entry is in the top-k
2503
+ `predictions`.
2504
+
2505
+ `recall_at_k` creates two local variables,
2506
+ `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
2507
+ the recall_at_k frequency. This frequency is ultimately returned as
2508
+ `recall_at_<k>`: an idempotent operation that simply divides
2509
+ `true_positive_at_<k>` by total (`true_positive_at_<k>` +
2510
+ `false_negative_at_<k>`).
2511
+
2512
+ For estimation of the metric over a stream of data, the function creates an
2513
+ `update_op` operation that updates these variables and returns the
2514
+ `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2515
+ indicating the top `k` `predictions`. Set operations applied to `top_k` and
2516
+ `labels` calculate the true positives and false negatives weighted by
2517
+ `weights`. Then `update_op` increments `true_positive_at_<k>` and
2518
+ `false_negative_at_<k>` using these values.
2519
+
2520
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2521
+
2522
+ Args:
2523
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2524
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2525
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
2526
+ the associated prediction. Commonly, N=1 and `labels` has shape
2527
+ [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2528
+ should be in range [0, num_classes), where num_classes is the last
2529
+ dimension of `predictions`. Values outside this range always count
2530
+ towards `false_negative_at_<k>`.
2531
+ predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2532
+ N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
2533
+ The final dimension contains the logit values for each class. [D1, ... DN]
2534
+ must match `labels`.
2535
+ k: Integer, k for @k metric.
2536
+ class_id: Integer class ID for which we want binary metrics. This should be
2537
+ in range [0, num_classes), where num_classes is the last dimension of
2538
+ `predictions`. If class_id is outside this range, the method returns NAN.
2539
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2540
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2541
+ dimensions must be either `1`, or the same as the corresponding `labels`
2542
+ dimension).
2543
+ metrics_collections: An optional list of collections that values should
2544
+ be added to.
2545
+ updates_collections: An optional list of collections that updates should
2546
+ be added to.
2547
+ name: Name of new update operation, and namespace for other dependent ops.
2548
+
2549
+ Returns:
2550
+ recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2551
+ by the sum of `true_positives` and `false_negatives`.
2552
+ update_op: `Operation` that increments `true_positives` and
2553
+ `false_negatives` variables appropriately, and whose value matches
2554
+ `recall`.
2555
+
2556
+ Raises:
2557
+ ValueError: If `weights` is not `None` and its shape doesn't match
2558
+ `predictions`, or if either `metrics_collections` or `updates_collections`
2559
+ are not a list or tuple.
2560
+ RuntimeError: If eager execution is enabled.
2561
+ """
2562
+ if context.executing_eagerly():
2563
+ raise RuntimeError('tf.metrics.recall_at_k is not '
2564
+ 'supported when eager execution is enabled.')
2565
+
2566
+ with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2567
+ (predictions, labels, weights)) as scope:
2568
+ _, top_k_idx = nn.top_k(predictions, k)
2569
+ return recall_at_top_k(
2570
+ labels=labels,
2571
+ predictions_idx=top_k_idx,
2572
+ k=k,
2573
+ class_id=class_id,
2574
+ weights=weights,
2575
+ metrics_collections=metrics_collections,
2576
+ updates_collections=updates_collections,
2577
+ name=scope)
2578
+
2579
+
2580
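A minimal end-to-end sketch of `recall_at_k` via the public `tf.compat.v1.metrics` endpoint (illustrative values):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Two examples, four classes; each row of `labels` holds the true class id.
labels = tf.constant([[2], [1]], dtype=tf.int64)
logits = tf.constant([[0.1, 0.4, 0.3, 0.2],
                      [0.3, 0.1, 0.4, 0.2]])

rec, update_op = tf.metrics.recall_at_k(labels, logits, k=2)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)   # row 0: class 2 is in the top-2; row 1: class 1 is missed
  print(sess.run(rec))  # 1 / (1 + 1) = 0.5
```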
+ @tf_export(v1=['metrics.recall_at_top_k'])
2581
+ def recall_at_top_k(labels,
2582
+ predictions_idx,
2583
+ k=None,
2584
+ class_id=None,
2585
+ weights=None,
2586
+ metrics_collections=None,
2587
+ updates_collections=None,
2588
+ name=None):
2589
+ """Computes recall@k of top-k predictions with respect to sparse labels.
2590
+
2591
+ Differs from `recall_at_k` in that predictions must be in the form of top `k`
2592
+ class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
2593
+ for more details.
2594
+
2595
+ Args:
2596
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2597
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2598
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
2599
+ the associated prediction. Commonly, N=1 and `labels` has shape
2600
+ [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2601
+ should be in range [0, num_classes), where num_classes is the last
2602
+ dimension of `predictions`. Values outside this range always count
2603
+ towards `false_negative_at_<k>`.
2604
+ predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
2605
+ Commonly, N=1 and predictions has shape [batch size, k]. The final
2606
+ dimension contains the top `k` predicted class indices. [D1, ... DN] must
2607
+ match `labels`.
2608
+ k: Integer, k for @k metric. Only used for the default op name.
2609
+ class_id: Integer class ID for which we want binary metrics. This should be
2610
+ in range [0, num_classes), where num_classes is the last dimension of
2611
+ `predictions`. If class_id is outside this range, the method returns NAN.
2612
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2613
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2614
+ dimensions must be either `1`, or the same as the corresponding `labels`
2615
+ dimension).
2616
+ metrics_collections: An optional list of collections that values should
2617
+ be added to.
2618
+ updates_collections: An optional list of collections that updates should
2619
+ be added to.
2620
+ name: Name of new update operation, and namespace for other dependent ops.
2621
+
2622
+ Returns:
2623
+ recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2624
+ by the sum of `true_positives` and `false_negatives`.
2625
+ update_op: `Operation` that increments `true_positives` and
2626
+ `false_negatives` variables appropriately, and whose value matches
2627
+ `recall`.
2628
+
2629
+ Raises:
2630
+ ValueError: If `weights` is not `None` and its shape doesn't match
2631
+ `predictions`, or if either `metrics_collections` or `updates_collections`
2632
+ are not a list or tuple.
2633
+ """
2634
+ with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2635
+ (predictions_idx, labels, weights)) as scope:
2636
+ labels = _maybe_expand_labels(labels, predictions_idx)
2637
+ top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
2638
+ tp, tp_update = _streaming_sparse_true_positive_at_k(
2639
+ predictions_idx=top_k_idx,
2640
+ labels=labels,
2641
+ k=k,
2642
+ class_id=class_id,
2643
+ weights=weights)
2644
+ fn, fn_update = _streaming_sparse_false_negative_at_k(
2645
+ predictions_idx=top_k_idx,
2646
+ labels=labels,
2647
+ k=k,
2648
+ class_id=class_id,
2649
+ weights=weights)
2650
+
2651
+ def compute_recall(_, tp, fn):
2652
+ return math_ops.div(tp, math_ops.add(tp, fn), name=scope)
2653
+
2654
+ metric = _aggregate_across_replicas(metrics_collections, compute_recall, tp,
2655
+ fn)
2656
+
2657
+ update = math_ops.div(
2658
+ tp_update, math_ops.add(tp_update, fn_update), name='update')
2659
+ if updates_collections:
2660
+ ops.add_to_collections(updates_collections, update)
2661
+ return metric, update
2662
+
2663
+
2664
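When the top-k indices are already materialized (for example, precomputed by a serving pipeline), `recall_at_top_k` can be fed those indices directly; a sketch with the same illustrative data as above:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([[2], [1]], dtype=tf.int64)
top_k_idx = tf.constant([[1, 2], [2, 0]], dtype=tf.int64)  # precomputed top-2

rec, update_op = tf.metrics.recall_at_top_k(labels, top_k_idx, k=2)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(rec))  # 0.5, matching recall_at_k on the equivalent logits
```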
+ @tf_export(v1=['metrics.recall_at_thresholds'])
2665
+ def recall_at_thresholds(labels,
2666
+ predictions,
2667
+ thresholds,
2668
+ weights=None,
2669
+ metrics_collections=None,
2670
+ updates_collections=None,
2671
+ name=None):
2672
+ """Computes various recall values for different `thresholds` on `predictions`.
2673
+
2674
+ The `recall_at_thresholds` function creates four local variables,
2675
+ `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
2676
+ for various values of thresholds. `recall[i]` is defined as the total weight
2677
+ of values in `predictions` above `thresholds[i]` whose corresponding entry in
2678
+ `labels` is `True`, divided by the total weight of `True` values in `labels`
2679
+ (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
2680
+
2681
+ For estimation of the metric over a stream of data, the function creates an
2682
+ `update_op` operation that updates these variables and returns the `recall`.
2683
+
2684
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2685
+
2686
+ Args:
2687
+ labels: The ground truth values, a `Tensor` whose dimensions must match
2688
+ `predictions`. Will be cast to `bool`.
2689
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
2690
+ are in the range `[0, 1]`.
2691
+ thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2692
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
2693
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2694
+ be either `1`, or the same as the corresponding `labels` dimension).
2695
+ metrics_collections: An optional list of collections that `recall` should be
2696
+ added to.
2697
+ updates_collections: An optional list of collections that `update_op` should
2698
+ be added to.
2699
+ name: An optional variable_scope name.
2700
+
2701
+ Returns:
2702
+ recall: A float `Tensor` of shape `[len(thresholds)]`.
2703
+ update_op: An operation that increments the `true_positives`,
2704
+ `true_negatives`, `false_positives` and `false_negatives` variables that
2705
+ are used in the computation of `recall`.
2706
+
2707
+ Raises:
2708
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
2709
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
2710
+ either `metrics_collections` or `updates_collections` are not a list or
2711
+ tuple.
2712
+ RuntimeError: If eager execution is enabled.
2713
+ """
2714
+ if context.executing_eagerly():
2715
+ raise RuntimeError('tf.metrics.recall_at_thresholds is not '
2716
+ 'supported when eager execution is enabled.')
2717
+
2718
+ with variable_scope.variable_scope(name, 'recall_at_thresholds',
2719
+ (predictions, labels, weights)):
2720
+ values, update_ops = _confusion_matrix_at_thresholds(
2721
+ labels, predictions, thresholds, weights, includes=('tp', 'fn'))
2722
+
2723
+ # Avoid division by zero.
2724
+ epsilon = 1e-7
2725
+
2726
+ def compute_recall(tp, fn, name):
2727
+ return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
2728
+
2729
+ def recall_across_replicas(_, values):
2730
+ return compute_recall(values['tp'], values['fn'], 'value')
2731
+
2732
+ rec = _aggregate_across_replicas(metrics_collections,
2733
+ recall_across_replicas, values)
2734
+
2735
+ update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
2736
+ if updates_collections:
2737
+ ops.add_to_collections(updates_collections, update_op)
2738
+
2739
+ return rec, update_op
2740
+
2741
+
2742
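The thresholded recall mirrors the thresholded precision shown earlier; a minimal sketch (illustrative values):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([True, False, True, True])
predictions = tf.constant([0.9, 0.6, 0.4, 0.65])

rec, update_op = tf.metrics.recall_at_thresholds(
    labels, predictions, thresholds=[0.5, 0.7])

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(rec))  # approx. [0.667, 0.333]
```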
+ @tf_export(v1=['metrics.root_mean_squared_error'])
2743
+ def root_mean_squared_error(labels,
2744
+ predictions,
2745
+ weights=None,
2746
+ metrics_collections=None,
2747
+ updates_collections=None,
2748
+ name=None):
2749
+ """Computes the root mean squared error between the labels and predictions.
2750
+
2751
+ The `root_mean_squared_error` function creates two local variables,
2752
+ `total` and `count` that are used to compute the root mean squared error.
2753
+ This average is weighted by `weights`, and it is ultimately returned as
2754
+ `root_mean_squared_error`: an idempotent operation that takes the square root
2755
+ of the division of `total` by `count`.
2756
+
2757
+ For estimation of the metric over a stream of data, the function creates an
2758
+ `update_op` operation that updates these variables and returns the
2759
+ `root_mean_squared_error`. Internally, a `squared_error` operation computes
2760
+ the element-wise square of the difference between `predictions` and `labels`.
2761
+ Then `update_op` increments `total` with the reduced sum of the product of
2762
+ `weights` and `squared_error`, and it increments `count` with the reduced sum
2763
+ of `weights`.
2764
+
2765
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2766
+
2767
+ Args:
2768
+ labels: A `Tensor` of the same shape as `predictions`.
2769
+ predictions: A `Tensor` of arbitrary shape.
2770
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
2771
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2772
+ be either `1`, or the same as the corresponding `labels` dimension).
2773
+ metrics_collections: An optional list of collections that
2774
+ `root_mean_squared_error` should be added to.
2775
+ updates_collections: An optional list of collections that `update_op` should
2776
+ be added to.
2777
+ name: An optional variable_scope name.
2778
+
2779
+ Returns:
2780
+ root_mean_squared_error: A `Tensor` representing the current mean, the value
2781
+ of `total` divided by `count`.
2782
+ update_op: An operation that increments the `total` and `count` variables
2783
+ appropriately and whose value matches `root_mean_squared_error`.
2784
+
2785
+ Raises:
2786
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
2787
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
2788
+ either `metrics_collections` or `updates_collections` are not a list or
2789
+ tuple.
2790
+ RuntimeError: If eager execution is enabled.
2791
+ """
2792
+ if context.executing_eagerly():
2793
+ raise RuntimeError('tf.metrics.root_mean_squared_error is not '
2794
+ 'supported when eager execution is enabled.')
2795
+
2796
+ predictions, labels, weights = _remove_squeezable_dimensions(
2797
+ predictions=predictions, labels=labels, weights=weights)
2798
+ mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
2799
+ None, name or
2800
+ 'root_mean_squared_error')
2801
+
2802
+ once_across_replicas = lambda _, mse: math_ops.sqrt(mse) # noqa: E731
2803
+ rmse = _aggregate_across_replicas(metrics_collections, once_across_replicas,
2804
+ mse)
2805
+
2806
+ update_rmse_op = math_ops.sqrt(update_mse_op)
2807
+ if updates_collections:
2808
+ ops.add_to_collections(updates_collections, update_rmse_op)
2809
+
2810
+ return rmse, update_rmse_op
2811
+
2812
+
2813
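A minimal sketch of the streaming RMSE above (public endpoint, illustrative values):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([1.0, 2.0, 3.0])
predictions = tf.constant([1.0, 2.0, 5.0])

rmse, update_op = tf.metrics.root_mean_squared_error(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(rmse))  # sqrt((0 + 0 + 4) / 3) = 1.1547...
```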
+ @tf_export(v1=['metrics.sensitivity_at_specificity'])
2814
+ def sensitivity_at_specificity(labels,
2815
+ predictions,
2816
+ specificity,
2817
+ weights=None,
2818
+ num_thresholds=200,
2819
+ metrics_collections=None,
2820
+ updates_collections=None,
2821
+ name=None):
2822
+ """Computes the specificity at a given sensitivity.
2823
+
2824
+ The `sensitivity_at_specificity` function creates four local
2825
+ variables, `true_positives`, `true_negatives`, `false_positives` and
2826
+ `false_negatives` that are used to compute the sensitivity at the given
2827
+ specificity value. The threshold for the given specificity value is computed
2828
+ and used to evaluate the corresponding sensitivity.
2829
+
2830
+ For estimation of the metric over a stream of data, the function creates an
2831
+ `update_op` operation that updates these variables and returns the
2832
+ `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
2833
+ `false_positives` and `false_negatives` counts with the weight of each case
2834
+ found in the `predictions` and `labels`.
2835
+
2836
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2837
+
2838
+ For additional information about specificity and sensitivity, see the
2839
+ following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
2840
+
2841
+ Args:
2842
+ labels: The ground truth values, a `Tensor` whose dimensions must match
2843
+ `predictions`. Will be cast to `bool`.
2844
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
2845
+ are in the range `[0, 1]`.
2846
+ specificity: A scalar value in range `[0, 1]`.
2847
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
2848
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2849
+ be either `1`, or the same as the corresponding `labels` dimension).
2850
+ num_thresholds: The number of thresholds to use for matching the given
2851
+ specificity.
2852
+ metrics_collections: An optional list of collections that `sensitivity`
2853
+ should be added to.
2854
+ updates_collections: An optional list of collections that `update_op` should
2855
+ be added to.
2856
+ name: An optional variable_scope name.
2857
+
2858
+ Returns:
2859
+ sensitivity: A scalar `Tensor` representing the sensitivity at the given
2860
+ `specificity` value.
2861
+ update_op: An operation that increments the `true_positives`,
2862
+ `true_negatives`, `false_positives` and `false_negatives` variables
2863
+ appropriately and whose value matches `sensitivity`.
2864
+
2865
+ Raises:
2866
+ ValueError: If `predictions` and `labels` have mismatched shapes, if
2867
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
2868
+ `specificity` is not between 0 and 1, or if either `metrics_collections`
2869
+ or `updates_collections` are not a list or tuple.
2870
+ RuntimeError: If eager execution is enabled.
2871
+ """
2872
+ if context.executing_eagerly():
2873
+ raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
2874
+ 'supported when eager execution is enabled.')
2875
+
2876
+ if specificity < 0 or specificity > 1:
2877
+ raise ValueError('`specificity` must be in the range [0, 1].')
2878
+
2879
+ with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
2880
+ (predictions, labels, weights)):
2881
+ kepsilon = 1e-7 # to account for floating point imprecisions
2882
+ thresholds = [
2883
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
2884
+ ]
2885
+ thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
2886
+
2887
+ values, update_ops = _confusion_matrix_at_thresholds(
2888
+ labels, predictions, thresholds, weights)
2889
+
2890
+ def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
2891
+ specificities = math_ops.div(tn, tn + fp + kepsilon)
2892
+ tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
2893
+ tf_index = math_ops.cast(tf_index, dtypes.int32)
2894
+
2895
+ # Now, we have the implicit threshold, so compute the sensitivity:
2896
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
2897
+ name)
2898
+
2899
+ def sensitivity_across_replicas(_, values):
2900
+ return compute_sensitivity_at_specificity(values['tp'], values['tn'],
2901
+ values['fp'], values['fn'],
2902
+ 'value')
2903
+
2904
+ sensitivity = _aggregate_across_replicas(metrics_collections,
2905
+ sensitivity_across_replicas,
2906
+ values)
2907
+
2908
+ update_op = compute_sensitivity_at_specificity(update_ops['tp'],
2909
+ update_ops['tn'],
2910
+ update_ops['fp'],
2911
+ update_ops['fn'],
2912
+ 'update_op')
2913
+ if updates_collections:
2914
+ ops.add_to_collections(updates_collections, update_op)
2915
+
2916
+ return sensitivity, update_op
2917
+
2918
+
2919
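A minimal sketch of the sensitivity-at-specificity metric above. With these illustrative values, specificity 0.75 corresponds to thresholds between 0.6 and 0.7, where only one of the two positives (the 0.9 score) still clears the threshold:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([True, True, False, False, False, False])
predictions = tf.constant([0.9, 0.5, 0.3, 0.4, 0.6, 0.7])

sens, update_op = tf.metrics.sensitivity_at_specificity(
    labels, predictions, specificity=0.75)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(sens))  # approx. 0.5: only the 0.9 positive clears it
```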
+ def _expand_and_tile(tensor, multiple, dim=0, name=None):
2920
+ """Slice `tensor` shape in 2, then tile along the sliced dimension.
2921
+
2922
+ A new dimension is inserted in shape of `tensor` before `dim`, then values are
2923
+ tiled `multiple` times along the new dimension.
2924
+
2925
+ Args:
2926
+ tensor: Input `Tensor` or `SparseTensor`.
2927
+ multiple: Integer, number of times to tile.
2928
+ dim: Integer, dimension along which to tile.
2929
+ name: Name of operation.
2930
+
2931
+ Returns:
2932
+ `Tensor` result of expanding and tiling `tensor`.
2933
+
2934
+ Raises:
2935
+ ValueError: if `multiple` is less than 1, or `dim` is not in
2936
+ `[-rank(tensor), rank(tensor)]`.
2937
+ """
2938
+ if multiple < 1:
2939
+ raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
2940
+ with ops.name_scope(name, 'expand_and_tile',
2941
+ (tensor, multiple, dim)) as scope:
2942
+ # Sparse.
2943
+ tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
2944
+ if isinstance(tensor, sparse_tensor.SparseTensor):
2945
+ if dim < 0:
2946
+ expand_dims = array_ops.reshape(
2947
+ array_ops.size(tensor.dense_shape) + dim, [1])
2948
+ else:
2949
+ expand_dims = [dim]
2950
+ expanded_shape = array_ops.concat(
2951
+ (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
2952
+ array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
2953
+ 0,
2954
+ name='expanded_shape')
2955
+ expanded = sparse_ops.sparse_reshape(
2956
+ tensor, shape=expanded_shape, name='expand')
2957
+ if multiple == 1:
2958
+ return expanded
2959
+ return sparse_ops.sparse_concat(
2960
+ dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
2961
+
2962
+ # Dense.
2963
+ expanded = array_ops.expand_dims(
2964
+ tensor, dim if (dim >= 0) else (dim - 1), name='expand')
2965
+ if multiple == 1:
2966
+ return expanded
2967
+ ones = array_ops.ones_like(array_ops.shape(tensor))
2968
+ tile_multiples = array_ops.concat((ones[:dim], (multiple,), ones[dim:]),
2969
+ 0,
2970
+ name='multiples')
2971
+ return array_ops.tile(expanded, tile_multiples, name=scope)
2972
+
2973
+
2974
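For a dense input, `_expand_and_tile` is equivalent to an `expand_dims` followed by a `tile` along the new axis; a short sketch of that dense branch (illustrative values):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

t = tf.constant([[1, 2], [3, 4]])     # shape [2, 2]
expanded = tf.expand_dims(t, 1)       # shape [2, 1, 2]
tiled = tf.tile(expanded, [1, 3, 1])  # shape [2, 3, 2]

with tf.Session() as sess:
  print(sess.run(tiled).shape)  # (2, 3, 2)
```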
+ def _num_relevant(labels, k):
2975
+ """Computes number of relevant values for each row in labels.
2976
+
2977
+ For labels with shape [D1, ... DN, num_labels], this is the minimum of
2978
+ `num_labels` and `k`.
2979
+
2980
+ Args:
2981
+ labels: `int64` `Tensor` or `SparseTensor` with shape
2982
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2983
+ target classes for the associated prediction. Commonly, N=1 and `labels`
2984
+ has shape [batch_size, num_labels].
2985
+ k: Integer, k for @k metric.
2986
+
2987
+ Returns:
2988
+ Integer `Tensor` of shape [D1, ... DN], where each value is the number of
2989
+ relevant values for that row.
2990
+
2991
+ Raises:
2992
+ ValueError: if inputs have invalid dtypes or values.
2993
+ """
2994
+ if k < 1:
2995
+ raise ValueError('Invalid k=%s.' % k)
2996
+ with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
2997
+ # For SparseTensor, calculate separate count for each row.
2998
+ labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
2999
+ if isinstance(labels, sparse_tensor.SparseTensor):
3000
+ return math_ops.minimum(sets.set_size(labels), k, name=scope)
3001
+
3002
+ # The relevant values for each (d1, ... dN) is the minimum of k and the
3003
+ # number of labels along the last dimension that are non-negative.
3004
+ num_labels = math_ops.reduce_sum(
3005
+ array_ops.where_v2(
3006
+ math_ops.greater_equal(labels, 0), array_ops.ones_like(labels),
3007
+ array_ops.zeros_like(labels)),
3008
+ axis=-1)
3009
+ return math_ops.minimum(num_labels, k, name=scope)
3010
+
3011
+
3012
+ def _sparse_average_precision_at_top_k(labels, predictions_idx):
3013
+ """Computes average precision@k of predictions with respect to sparse labels.
3014
+
3015
+ From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
3016
+ for each row is:
3017
+
3018
+ AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
3019
+
3020
+ A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
3021
+ `labels`, and the result `Tensors`. In the common case, this is [batch_size].
3022
+ Each row of the results contains the average precision for that row.
3023
+
3024
+ Args:
3025
+ labels: `int64` `Tensor` or `SparseTensor` with shape
3026
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3027
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
3028
+ the associated prediction. Commonly, N=1 and `labels` has shape
3029
+ [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3030
+ Values should be non-negative. Negative values are ignored.
3031
+ predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3032
+ Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3033
+ dimension must be set and contains the top `k` predicted class indices.
3034
+ [D1, ... DN] must match `labels`. Values should be in range
3035
+ [0, num_classes).
3036
+
3037
+ Returns:
3038
+ `float64` `Tensor` of shape [D1, ... DN], where each value is the average
3039
+ precision for that row.
3040
+
3041
+ Raises:
3042
+ ValueError: if the last dimension of predictions_idx is not set.
3043
+ """
3044
+ with ops.name_scope(None, 'average_precision',
3045
+ (predictions_idx, labels)) as scope:
3046
+ predictions_idx = math_ops.cast(
3047
+ predictions_idx, dtypes.int64, name='predictions_idx')
3048
+ if predictions_idx.get_shape().ndims == 0:
3049
+ raise ValueError('The rank of predictions_idx must be at least 1.')
3050
+ k = predictions_idx.get_shape().as_list()[-1]
3051
+ if k is None:
3052
+ raise ValueError('The last dimension of predictions_idx must be set.')
3053
+ labels = _maybe_expand_labels(labels, predictions_idx)
3054
+
3055
+ # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
3056
+ # prediction for each k, so we can calculate separate true positive values
3057
+ # for each k.
3058
+ predictions_idx_per_k = array_ops.expand_dims(
3059
+ predictions_idx, -1, name='predictions_idx_per_k')
3060
+
3061
+ # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
3062
+ labels_per_k = _expand_and_tile(
3063
+ labels, multiple=k, dim=-1, name='labels_per_k')
3064
+
3065
+ # The following tensors are all of shape [D1, ... DN, k], containing values
3066
+ # per row, per k value.
3067
+ # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
3068
+ # that k value is correct, 0 otherwise. This is the "rel_{i}" term from
3069
+ # the formula above.
3070
+ # `tp_per_k` (int32) - True positive counts.
3071
+ # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
3072
+ # the precision denominator.
3073
+ # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
3074
+ # term from the formula above.
3075
+ # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
3076
+ # precisions at all k for which relevance indicator is true.
3077
+ relevant_per_k = _sparse_true_positive_at_k(
3078
+ labels_per_k, predictions_idx_per_k, name='relevant_per_k')
3079
+ tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
3080
+ retrieved_per_k = math_ops.cumsum(
3081
+ array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
3082
+ precision_per_k = math_ops.div(
3083
+ math_ops.cast(tp_per_k, dtypes.float64),
3084
+ math_ops.cast(retrieved_per_k, dtypes.float64),
3085
+ name='precision_per_k')
3086
+ relevant_precision_per_k = math_ops.multiply(
3087
+ precision_per_k,
3088
+ math_ops.cast(relevant_per_k, dtypes.float64),
3089
+ name='relevant_precision_per_k')
3090
+
3091
+ # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
3092
+ precision_sum = math_ops.reduce_sum(
3093
+ relevant_precision_per_k, axis=(-1,), name='precision_sum')
3094
+
3095
+ # Divide by number of relevant items to get average precision. These are
3096
+ # the "num_relevant_items" and "AveP" terms from the formula above.
3097
+ num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64)
3098
+ return math_ops.div(precision_sum, num_relevant_items, name=scope)
3099
+
3100
+
3101
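To make the per-row AveP formula concrete, a worked instance in plain Python for one row with top-3 predictions `[2, 1, 0]` and relevant label set `{0, 2}` (illustrative values):

```python
rel = [1, 0, 1]                    # relevance indicator per rank
precision = [1 / 1, 1 / 2, 2 / 3]  # P_i = cumulative tp / rank
num_relevant = min(2, 3)           # min(num_labels, k)

avep = sum(p * r for p, r in zip(precision, rel)) / num_relevant
print(avep)  # 0.8333... = (1 + 2/3) / 2
```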
+ def _streaming_sparse_average_precision_at_top_k(labels,
3102
+ predictions_idx,
3103
+ weights=None,
3104
+ metrics_collections=None,
3105
+ updates_collections=None,
3106
+ name=None):
3107
+ """Computes average precision@k of predictions with respect to sparse labels.
3108
+
3109
+ This helper creates two local variables,
3110
+ `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3111
+ are used to compute the frequency. This frequency is ultimately returned as
3112
+ `average_precision_at_<k>`: an idempotent operation that simply divides
3113
+ `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3114
+
3115
+ For estimation of the metric over a stream of data, the function creates an
3116
+ `update_op` operation that updates these variables and returns the
3117
+ `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate
3118
+ the true positives and false positives weighted by `weights`. Then `update_op`
3119
+ increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
3120
+ values.
3121
+
3122
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3123
+
3124
+ Args:
3125
+ labels: `int64` `Tensor` or `SparseTensor` with shape
3126
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3127
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
3128
+ the associated prediction. Commonly, N=1 and `labels` has shape
3129
+ [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3130
+ Values should be non-negative. Negative values are ignored.
3131
+ predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3132
+ Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3133
+ dimension contains the top `k` predicted class indices. [D1, ... DN] must
3134
+ match `labels`. Values should be in range [0, num_classes).
3135
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3136
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3137
+ dimensions must be either `1`, or the same as the corresponding `labels`
3138
+ dimension).
3139
+ metrics_collections: An optional list of collections that values should
3140
+ be added to.
3141
+ updates_collections: An optional list of collections that updates should
3142
+ be added to.
3143
+ name: Name of new update operation, and namespace for other dependent ops.
3144
+
3145
+ Returns:
3146
+ mean_average_precision: Scalar `float64` `Tensor` with the mean average
3147
+ precision values.
3148
+ update: `Operation` that increments variables appropriately, and whose
3149
+ value matches `metric`.
3150
+ """
3151
+ with ops.name_scope(name, 'average_precision_at_top_k',
3152
+ (predictions_idx, labels, weights)) as scope:
3153
+ # Calculate per-example average precision, and apply weights.
3154
+ average_precision = _sparse_average_precision_at_top_k(
3155
+ predictions_idx=predictions_idx, labels=labels)
3156
+ if weights is not None:
3157
+ weights = weights_broadcast_ops.broadcast_weights(
3158
+ math_ops.cast(weights, dtypes.float64), average_precision)
3159
+ average_precision = math_ops.multiply(average_precision, weights)
3160
+
3161
+ # Create accumulation variables and update ops for max average precision and
3162
+ # total average precision.
3163
+ with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
3164
+ # `max` is the max possible precision. Since max for any row is 1.0:
3165
+ # - For the unweighted case, this is just the number of rows.
3166
+ # - For the weighted case, it's the sum of the weights broadcast across
3167
+ # `average_precision` rows.
3168
+ max_var = metric_variable([], dtypes.float64, name=max_scope)
3169
+ if weights is None:
3170
+ batch_max = math_ops.cast(
3171
+ array_ops.size(average_precision, name='batch_max'), dtypes.float64)
3172
+ else:
3173
+ batch_max = math_ops.reduce_sum(weights, name='batch_max')
3174
+ max_update = state_ops.assign_add(
3175
+ max_var, batch_max, name='update', use_locking=True)
3176
+ with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
3177
+ total_var = metric_variable([], dtypes.float64, name=total_scope)
3178
+ batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
3179
+ total_update = state_ops.assign_add(
3180
+ total_var, batch_total, name='update', use_locking=True)
3181
+
3182
+ # Divide total by max to get mean, for both vars and the update ops.
3183
+ def precision_across_replicas(_, total_var, max_var):
3184
+ return _safe_scalar_div(total_var, max_var, name='mean')
3185
+
3186
+ mean_average_precision = _aggregate_across_replicas(
3187
+ metrics_collections, precision_across_replicas, total_var, max_var)
3188
+
3189
+ update = _safe_scalar_div(total_update, max_update, name=scope)
3190
+ if updates_collections:
3191
+ ops.add_to_collections(updates_collections, update)
3192
+
3193
+ return mean_average_precision, update
3194
+
3195
+
3196
+ def _clean_out_of_range_indices(labels, num_classes):
3197
+ """Replaces large out-of-range labels by small out-of-range labels.
3198
+
3199
+ Replaces any value in `labels` that is greater than or equal to `num_classes`
3200
+ -1. Do this conditionally for efficiency in case there are no such values.
3201
+
3202
+ Args:
3203
+ labels: `int64` `Tensor` or `SparseTensor`.
3204
+ num_classes: `int64` scalar `Tensor`.
3205
+
3206
+ Returns:
3207
+ An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater
3208
+ than or equal to num_classes replaced by -1.
3209
+ """
3210
+
3211
+ def _labels_is_sparse():
3212
+ """Returns true is `labels` is a sparse tensor."""
3213
+ return isinstance(
3214
+ labels, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))
3215
+
3216
+ def _clean_out_of_range(values):
3217
+ """Replaces by -1 any large out-of-range `values`."""
3218
+ return array_ops.where_v2(
3219
+ math_ops.greater_equal(values, num_classes),
3220
+ -1 * array_ops.ones_like(values), values)
3221
+
3222
+ def _clean_labels_out_of_range():
3223
+ """Replaces by -1 ane large out-of-range values in `labels`."""
3224
+ if _labels_is_sparse():
3225
+ return type(labels)(
3226
+ indices=labels.indices,
3227
+ values=_clean_out_of_range(labels.values),
3228
+ dense_shape=labels.dense_shape)
3229
+ else:
3230
+ return _clean_out_of_range(labels)
3231
+
3232
+ max_labels = math_ops.reduce_max(
3233
+ labels.values if _labels_is_sparse() else labels)
3234
+ return control_flow_ops.cond(
3235
+ math_ops.greater_equal(max_labels, num_classes),
3236
+ _clean_labels_out_of_range, lambda: labels)
3237
+
3238
+
3239
+ @tf_export(v1=['metrics.sparse_average_precision_at_k'])
3240
+ @deprecated(None, 'Use average_precision_at_k instead')
3241
+ def sparse_average_precision_at_k(labels,
3242
+ predictions,
3243
+ k,
3244
+ weights=None,
3245
+ metrics_collections=None,
3246
+ updates_collections=None,
3247
+ name=None):
3248
+ """Renamed to `average_precision_at_k`, please use that method instead."""
3249
+ return average_precision_at_k(
3250
+ labels=labels,
3251
+ predictions=predictions,
3252
+ k=k,
3253
+ weights=weights,
3254
+ metrics_collections=metrics_collections,
3255
+ updates_collections=updates_collections,
3256
+ name=name)
3257
+
3258
+
3259
+ @tf_export(v1=['metrics.average_precision_at_k'])
3260
+ def average_precision_at_k(labels,
3261
+ predictions,
3262
+ k,
3263
+ weights=None,
3264
+ metrics_collections=None,
3265
+ updates_collections=None,
3266
+ name=None):
3267
+ """Computes average precision@k of predictions with respect to sparse labels.
3268
+
3269
+ `average_precision_at_k` creates two local variables,
3270
+ `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3271
+ are used to compute the frequency. This frequency is ultimately returned as
3272
+ `average_precision_at_<k>`: an idempotent operation that simply divides
3273
+ `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3274
+
3275
+ For estimation of the metric over a stream of data, the function creates an
3276
+ `update_op` operation that updates these variables and returns the
3277
+ `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
3278
+ indicating the top `k` `predictions`. Set operations applied to `top_k` and
3279
+ `labels` calculate the true positives and false positives weighted by
3280
+ `weights`. Then `update_op` increments `true_positive_at_<k>` and
3281
+ `false_positive_at_<k>` using these values.
3282
+
3283
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3284
+
3285
+ Args:
3286
+ labels: `int64` `Tensor` or `SparseTensor` with shape
3287
+ [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3288
+ num_labels=1. N >= 1 and num_labels is the number of target classes for
3289
+ the associated prediction. Commonly, N=1 and `labels` has shape
3290
+ [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3291
+ should be in range [0, num_classes), where num_classes is the last
3292
+ dimension of `predictions`. Values outside this range are ignored.
3293
+ predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
3294
+ N >= 1. Commonly, N=1 and `predictions` has shape
3295
+ [batch size, num_classes]. The final dimension contains the logit values
3296
+ for each class. [D1, ... DN] must match `labels`.
3297
+ k: Integer, k for @k metric. This will calculate an average precision for
3298
+ range `[1,k]`, as documented above.
3299
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3300
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3301
+ dimensions must be either `1`, or the same as the corresponding `labels`
3302
+ dimension).
3303
+ metrics_collections: An optional list of collections that values should
3304
+ be added to.
3305
+ updates_collections: An optional list of collections that updates should
3306
+ be added to.
3307
+ name: Name of new update operation, and namespace for other dependent ops.
3308
+
3309
+ Returns:
3310
+ mean_average_precision: Scalar `float64` `Tensor` with the mean average
3311
+ precision values.
3312
+ update: `Operation` that increments variables appropriately, and whose
3313
+ value matches `metric`.
3314
+
3315
+ Raises:
3316
+ ValueError: if k is invalid.
3317
+ RuntimeError: If eager execution is enabled.
3318
+ """
3319
+ if context.executing_eagerly():
3320
+ raise RuntimeError('tf.metrics.average_precision_at_k is not '
3321
+ 'supported when eager execution is enabled.')
3322
+
3323
+ if k < 1:
3324
+ raise ValueError('Invalid k=%s.' % k)
3325
+ with ops.name_scope(name, _at_k_name('average_precision', k),
3326
+ (predictions, labels, weights)) as scope:
3327
+ # Calculate top k indices to produce [D1, ... DN, k] tensor.
3328
+ _, predictions_idx = nn.top_k(predictions, k)
3329
+ # The documentation states that labels should be in [0, ..., num_classes),
3330
+ # but num_classes is lost when predictions_idx replaces predictions.
3331
+ # For conformity with the documentation, any label >= num_classes, which is
3332
+ # ignored, is replaced by -1.
3333
+ labels = _clean_out_of_range_indices(
3334
+ labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64))
3335
+ return _streaming_sparse_average_precision_at_top_k(
3336
+ labels=labels,
3337
+ predictions_idx=predictions_idx,
3338
+ weights=weights,
3339
+ metrics_collections=metrics_collections,
3340
+ updates_collections=updates_collections,
3341
+ name=scope)
3342
+
3343
+
3344
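And the same computation driven end to end through the public `average_precision_at_k` endpoint, with logits chosen so the top-3 indices are `[2, 1, 0]`, matching the worked arithmetic above:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([[0, 2]], dtype=tf.int64)
logits = tf.constant([[0.2, 0.3, 0.5, 0.0]])  # top-3 class ids: [2, 1, 0]

map_k, update_op = tf.metrics.average_precision_at_k(labels, logits, k=3)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(map_k))  # 0.8333..., as in the worked example above
```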
+ def _sparse_false_positive_at_k(labels,
3345
+ predictions_idx,
3346
+ class_id=None,
3347
+ weights=None):
3348
+ """Calculates false positives for precision@k.
3349
+
3350
+ If `class_id` is specified, calculate binary false positives for `class_id`
3351
+ only.
3352
+ If `class_id` is not specified, calculate metrics for `k` predicted vs
3353
+ `n` label classes, where `n` is the 2nd dimension of `labels`.
3354
+
3355
+ Args:
3356
+ labels: `int64` `Tensor` or `SparseTensor` with shape
3357
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3358
+ target classes for the associated prediction. Commonly, N=1 and `labels`
3359
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
3360
+ `predictions_idx`.
3361
+ predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3362
+ top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3363
+ match `labels`.
3364
+ class_id: Class for which we want binary metrics.
3365
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3366
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3367
+ dimensions must be either `1`, or the same as the corresponding `labels`
3368
+ dimension).
3369
+
3370
+ Returns:
3371
+ A [D1, ... DN] `Tensor` of false positive counts.
3372
+ """
3373
+ with ops.name_scope(None, 'false_positives',
3374
+ (predictions_idx, labels, weights)):
3375
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
3376
+ class_id)
3377
+ fp = sets.set_size(
3378
+ sets.set_difference(predictions_idx, labels, aminusb=True))
3379
+ fp = math_ops.cast(fp, dtypes.float64)
3380
+ if weights is not None:
3381
+ with ops.control_dependencies(
3382
+ (weights_broadcast_ops.assert_broadcastable(weights, fp),)):
3383
+ weights = math_ops.cast(weights, dtypes.float64)
3384
+ fp = math_ops.multiply(fp, weights)
3385
+ return fp
3386
+
3387
+
3388
+ def _streaming_sparse_false_positive_at_k(labels,
3389
+ predictions_idx,
3390
+ k=None,
3391
+ class_id=None,
3392
+ weights=None,
3393
+ name=None):
3394
+ """Calculates weighted per step false positives for precision@k.
3395
+
3396
+ If `class_id` is specified, calculate binary false positives for `class_id`
3397
+ only.
3398
+ If `class_id` is not specified, calculate metrics for `k` predicted vs
3399
+ `n` label classes, where `n` is the 2nd dimension of `labels`.
3400
+
3401
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3402
+
3403
+ Args:
3404
+ labels: `int64` `Tensor` or `SparseTensor` with shape
3405
+ [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3406
+ target classes for the associated prediction. Commonly, N=1 and `labels`
3407
+ has shape [batch_size, num_labels]. [D1, ... DN] must match
3408
+ `predictions_idx`.
3409
+ predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3410
+ top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3411
+ match `labels`.
3412
+ k: Integer, k for @k metric. This is only used for default op name.
3413
+ class_id: Class for which we want binary metrics.
3414
+ weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3415
+ `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3416
+ dimensions must be either `1`, or the same as the corresponding `labels`
3417
+ dimension).
3418
+ name: Name of new variable, and namespace for other dependent ops.
3419
+
3420
+ Returns:
3421
+ A tuple of `Variable` and update `Operation`.
3422
+
3423
+ Raises:
3424
+ ValueError: If `weights` is not `None` and has an incompatible shape.
3425
+ """
3426
+ with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
3427
+ (predictions_idx, labels, weights)) as scope:
3428
+ fp = _sparse_false_positive_at_k(
3429
+ predictions_idx=predictions_idx,
3430
+ labels=labels,
3431
+ class_id=class_id,
3432
+ weights=weights)
3433
+ batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64)
3434
+
3435
+ var = metric_variable([], dtypes.float64, name=scope)
3436
+ return var, state_ops.assign_add(
3437
+ var, batch_total_fp, name='update', use_locking=True)
3438
+
3439
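+ # Sketch of the streaming pattern above: the returned variable holds a running
+ # total, and evaluating the returned assign_add op folds in one batch's
+ # false-positive count. The tensors below mirror the helper's signature and
+ # are illustrative only:
+ #
+ #   fp_total, fp_update = _streaming_sparse_false_positive_at_k(
+ #       labels=labels, predictions_idx=top_k_idx, k=2)
+ #   # run `fp_update` once per batch; read `fp_total` for the cumulative count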
+
+ @tf_export(v1=['metrics.precision_at_top_k'])
+ def precision_at_top_k(labels,
+                        predictions_idx,
+                        k=None,
+                        class_id=None,
+                        weights=None,
+                        metrics_collections=None,
+                        updates_collections=None,
+                        name=None):
+   """Computes precision@k of the predictions with respect to sparse labels.
+
+   Differs from `sparse_precision_at_k` in that predictions must be in the form
+   of top `k` class indices, whereas `sparse_precision_at_k` expects logits.
+   Refer to `sparse_precision_at_k` for more details.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+       should be in range [0, num_classes), where num_classes is the last
+       dimension of `predictions`. Values outside this range are ignored.
+     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
+       N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
+       The final dimension contains the top `k` predicted class indices.
+       [D1, ... DN] must match `labels`.
+     k: Integer, k for @k metric. Only used for the default op name.
+     class_id: Integer class ID for which we want binary metrics. This should be
+       in range [0, num_classes], where num_classes is the last dimension of
+       `predictions`. If `class_id` is outside this range, the method returns
+       NAN.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     precision: Scalar `float64` `Tensor` with the value of `true_positives`
+       divided by the sum of `true_positives` and `false_positives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_positives` variables appropriately, and whose value matches
+       `precision`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match
+       `predictions`, or if either `metrics_collections` or `updates_collections`
+       are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.precision_at_top_k is not '
+                        'supported when eager execution is enabled.')
+
+   with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
+                       (predictions_idx, labels, weights)) as scope:
+     labels = _maybe_expand_labels(labels, predictions_idx)
+     top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
+     tp, tp_update = _streaming_sparse_true_positive_at_k(
+         predictions_idx=top_k_idx,
+         labels=labels,
+         k=k,
+         class_id=class_id,
+         weights=weights)
+     fp, fp_update = _streaming_sparse_false_positive_at_k(
+         predictions_idx=top_k_idx,
+         labels=labels,
+         k=k,
+         class_id=class_id,
+         weights=weights)
+
+     def precision_across_replicas(_, tp, fp):
+       return math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+
+     metric = _aggregate_across_replicas(metrics_collections,
+                                         precision_across_replicas, tp, fp)
+
+     update = math_ops.div(
+         tp_update, math_ops.add(tp_update, fp_update), name='update')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update)
+     return metric, update
+
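+ # Example (minimal sketch; the index tensors are illustrative only). Unlike
+ # `precision_at_k`, the top-k class indices are supplied directly:
+ #
+ #   import tensorflow.compat.v1 as tf
+ #   labels = tf.constant([[2], [1]], dtype=tf.int64)
+ #   top_k_idx = tf.constant([[2, 0], [0, 1]], dtype=tf.int64)  # k == 2
+ #   prec, prec_update = precision_at_top_k(labels, top_k_idx, k=2)
+ #   # after one update, TP == 2 and FP == 2, so `prec` evaluates to 0.5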
+
+ @tf_export(v1=['metrics.sparse_precision_at_k'])
+ @deprecated(None, 'Use precision_at_k instead')
+ def sparse_precision_at_k(labels,
+                           predictions,
+                           k,
+                           class_id=None,
+                           weights=None,
+                           metrics_collections=None,
+                           updates_collections=None,
+                           name=None):
+   """Renamed to `precision_at_k`, please use that method instead."""
+   return precision_at_k(
+       labels=labels,
+       predictions=predictions,
+       k=k,
+       class_id=class_id,
+       weights=weights,
+       metrics_collections=metrics_collections,
+       updates_collections=updates_collections,
+       name=name)
+
+
+ @tf_export(v1=['metrics.precision_at_k'])
+ def precision_at_k(labels,
+                    predictions,
+                    k,
+                    class_id=None,
+                    weights=None,
+                    metrics_collections=None,
+                    updates_collections=None,
+                    name=None):
+   """Computes precision@k of the predictions with respect to sparse labels.
+
+   If `class_id` is specified, we calculate precision by considering only the
+   entries in the batch for which `class_id` is in the top-k highest
+   `predictions`, and computing the fraction of them for which `class_id` is
+   indeed a correct label.
+   If `class_id` is not specified, we'll calculate precision as how often on
+   average a class among the top-k classes with the highest predicted values
+   of a batch entry is correct and can be found in the label for that entry.
+
+   `precision_at_k` creates two local variables,
+   `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
+   the precision@k frequency. This frequency is ultimately returned as
+   `precision_at_<k>`: an idempotent operation that simply divides
+   `true_positive_at_<k>` by total (`true_positive_at_<k>` +
+   `false_positive_at_<k>`).
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
+   indicating the top `k` `predictions`. Set operations applied to `top_k` and
+   `labels` calculate the true positives and false positives weighted by
+   `weights`. Then `update_op` increments `true_positive_at_<k>` and
+   `false_positive_at_<k>` using these values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+       should be in range [0, num_classes), where num_classes is the last
+       dimension of `predictions`. Values outside this range are ignored.
+     predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
+       N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
+       The final dimension contains the logit values for each class. [D1, ... DN]
+       must match `labels`.
+     k: Integer, k for @k metric.
+     class_id: Integer class ID for which we want binary metrics. This should be
+       in range [0, num_classes], where num_classes is the last dimension of
+       `predictions`. If `class_id` is outside this range, the method returns
+       NAN.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     precision: Scalar `float64` `Tensor` with the value of `true_positives`
+       divided by the sum of `true_positives` and `false_positives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_positives` variables appropriately, and whose value matches
+       `precision`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match
+       `predictions`, or if either `metrics_collections` or `updates_collections`
+       are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.sparse_precision_at_k is not '
+                        'supported when eager execution is enabled.')
+
+   with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
+                       (predictions, labels, weights)) as scope:
+     _, top_k_idx = nn.top_k(predictions, k)
+     return precision_at_top_k(
+         labels=labels,
+         predictions_idx=top_k_idx,
+         k=k,
+         class_id=class_id,
+         weights=weights,
+         metrics_collections=metrics_collections,
+         updates_collections=updates_collections,
+         name=scope)
+
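+ # Example (minimal sketch; the logits are illustrative only). This variant
+ # takes raw scores and applies `top_k` internally before delegating to
+ # `precision_at_top_k`:
+ #
+ #   import tensorflow.compat.v1 as tf
+ #   labels = tf.constant([[2], [1]], dtype=tf.int64)
+ #   logits = tf.constant([[0.1, 0.2, 0.7],
+ #                         [0.8, 0.1, 0.1]])
+ #   prec, prec_update = precision_at_k(labels, logits, k=2)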
+
+ @tf_export(v1=['metrics.specificity_at_sensitivity'])
+ def specificity_at_sensitivity(labels,
+                                predictions,
+                                sensitivity,
+                                weights=None,
+                                num_thresholds=200,
+                                metrics_collections=None,
+                                updates_collections=None,
+                                name=None):
+   """Computes the specificity at a given sensitivity.
+
+   The `specificity_at_sensitivity` function creates four local
+   variables, `true_positives`, `true_negatives`, `false_positives` and
+   `false_negatives` that are used to compute the specificity at the given
+   sensitivity value. The threshold for the given sensitivity value is computed
+   and used to evaluate the corresponding specificity.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
+   `false_positives` and `false_negatives` counts with the weight of each case
+   found in the `predictions` and `labels`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   For additional information about specificity and sensitivity, see the
+   following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     sensitivity: A scalar value in range `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     num_thresholds: The number of thresholds to use for matching the given
+       sensitivity.
+     metrics_collections: An optional list of collections that `specificity`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     specificity: A scalar `Tensor` representing the specificity at the given
+       `sensitivity` value.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables
+       appropriately and whose value matches `specificity`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       `sensitivity` is not between 0 and 1, or if either `metrics_collections`
+       or `updates_collections` are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
+                        'supported when eager execution is enabled.')
+
+   if sensitivity < 0 or sensitivity > 1:
+     raise ValueError('`sensitivity` must be in the range [0, 1].')
+
+   with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
+                                      (predictions, labels, weights)):
+     kepsilon = 1e-7  # to account for floating point imprecisions
+     thresholds = [
+         (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+     ]
+     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
+
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights)
+
+     def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
+       """Computes the specificity at the given sensitivity.
+
+       Args:
+         tp: True positives.
+         tn: True negatives.
+         fp: False positives.
+         fn: False negatives.
+         name: The name of the operation.
+
+       Returns:
+         The specificity using the aggregated values.
+       """
+       sensitivities = math_ops.div(tp, tp + fn + kepsilon)
+
+       # We'll need to use this trick until tf.argmax allows us to specify
+       # whether we should use the first or last index in case of ties.
+       min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
+       indices_at_minval = math_ops.equal(
+           math_ops.abs(sensitivities - sensitivity), min_val)
+       indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64)
+       indices_at_minval = math_ops.cumsum(indices_at_minval)
+       tf_index = math_ops.argmax(indices_at_minval, 0)
+       tf_index = math_ops.cast(tf_index, dtypes.int32)
+
+       # Now, we have the implicit threshold, so compute the specificity:
+       return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
+                           name)
+
+     def specificity_across_replicas(_, values):
+       return compute_specificity_at_sensitivity(values['tp'], values['tn'],
+                                                 values['fp'], values['fn'],
+                                                 'value')
+
+     specificity = _aggregate_across_replicas(metrics_collections,
+                                              specificity_across_replicas,
+                                              values)
+
+     update_op = compute_specificity_at_sensitivity(update_ops['tp'],
+                                                    update_ops['tn'],
+                                                    update_ops['fp'],
+                                                    update_ops['fn'],
+                                                    'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return specificity, update_op
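+
+ # Example (minimal sketch; the scores are illustrative only). Reports the
+ # specificity at the threshold whose sensitivity is closest to the requested
+ # value:
+ #
+ #   import tensorflow.compat.v1 as tf
+ #   labels = tf.constant([0., 0., 1., 1.])
+ #   scores = tf.constant([0.1, 0.4, 0.35, 0.8])
+ #   spec, spec_update = specificity_at_sensitivity(
+ #       labels, scores, sensitivity=0.5)
+ #   # run `spec_update` per batch; `spec` is TN / (TN + FP) at that threshold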