easy-cs-rec-custommodel 0.8.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of easy-cs-rec-custommodel might be problematic.

Files changed (336)
  1. easy_cs_rec_custommodel-0.8.6.dist-info/LICENSE +203 -0
  2. easy_cs_rec_custommodel-0.8.6.dist-info/METADATA +48 -0
  3. easy_cs_rec_custommodel-0.8.6.dist-info/RECORD +336 -0
  4. easy_cs_rec_custommodel-0.8.6.dist-info/WHEEL +6 -0
  5. easy_cs_rec_custommodel-0.8.6.dist-info/top_level.txt +2 -0
  6. easy_rec/__init__.py +114 -0
  7. easy_rec/python/__init__.py +0 -0
  8. easy_rec/python/builders/__init__.py +0 -0
  9. easy_rec/python/builders/hyperparams_builder.py +78 -0
  10. easy_rec/python/builders/loss_builder.py +333 -0
  11. easy_rec/python/builders/optimizer_builder.py +211 -0
  12. easy_rec/python/builders/strategy_builder.py +44 -0
  13. easy_rec/python/compat/__init__.py +0 -0
  14. easy_rec/python/compat/adam_s.py +245 -0
  15. easy_rec/python/compat/array_ops.py +229 -0
  16. easy_rec/python/compat/dynamic_variable.py +542 -0
  17. easy_rec/python/compat/early_stopping.py +653 -0
  18. easy_rec/python/compat/embedding_ops.py +162 -0
  19. easy_rec/python/compat/embedding_parallel_saver.py +316 -0
  20. easy_rec/python/compat/estimator_train.py +116 -0
  21. easy_rec/python/compat/exporter.py +473 -0
  22. easy_rec/python/compat/feature_column/__init__.py +0 -0
  23. easy_rec/python/compat/feature_column/feature_column.py +3675 -0
  24. easy_rec/python/compat/feature_column/feature_column_v2.py +5233 -0
  25. easy_rec/python/compat/feature_column/sequence_feature_column.py +648 -0
  26. easy_rec/python/compat/feature_column/utils.py +154 -0
  27. easy_rec/python/compat/layers.py +329 -0
  28. easy_rec/python/compat/ops.py +14 -0
  29. easy_rec/python/compat/optimizers.py +619 -0
  30. easy_rec/python/compat/queues.py +311 -0
  31. easy_rec/python/compat/regularizers.py +208 -0
  32. easy_rec/python/compat/sok_optimizer.py +440 -0
  33. easy_rec/python/compat/sync_replicas_optimizer.py +528 -0
  34. easy_rec/python/compat/weight_decay_optimizers.py +475 -0
  35. easy_rec/python/core/__init__.py +0 -0
  36. easy_rec/python/core/easyrec_metrics/__init__.py +24 -0
  37. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py +3702 -0
  38. easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_tf.py +3768 -0
  39. easy_rec/python/core/learning_schedules.py +228 -0
  40. easy_rec/python/core/metrics.py +402 -0
  41. easy_rec/python/core/sampler.py +844 -0
  42. easy_rec/python/eval.py +102 -0
  43. easy_rec/python/export.py +150 -0
  44. easy_rec/python/feature_column/__init__.py +0 -0
  45. easy_rec/python/feature_column/feature_column.py +664 -0
  46. easy_rec/python/feature_column/feature_group.py +89 -0
  47. easy_rec/python/hpo/__init__.py +0 -0
  48. easy_rec/python/hpo/emr_hpo.py +140 -0
  49. easy_rec/python/hpo/generate_hpo_sql.py +71 -0
  50. easy_rec/python/hpo/pai_hpo.py +297 -0
  51. easy_rec/python/inference/__init__.py +0 -0
  52. easy_rec/python/inference/csv_predictor.py +189 -0
  53. easy_rec/python/inference/hive_parquet_predictor.py +200 -0
  54. easy_rec/python/inference/hive_predictor.py +166 -0
  55. easy_rec/python/inference/odps_predictor.py +70 -0
  56. easy_rec/python/inference/parquet_predictor.py +147 -0
  57. easy_rec/python/inference/parquet_predictor_v2.py +147 -0
  58. easy_rec/python/inference/predictor.py +621 -0
  59. easy_rec/python/inference/processor/__init__.py +0 -0
  60. easy_rec/python/inference/processor/test.py +170 -0
  61. easy_rec/python/inference/vector_retrieve.py +124 -0
  62. easy_rec/python/input/__init__.py +0 -0
  63. easy_rec/python/input/batch_tfrecord_input.py +117 -0
  64. easy_rec/python/input/criteo_binary_reader.py +259 -0
  65. easy_rec/python/input/criteo_input.py +107 -0
  66. easy_rec/python/input/csv_input.py +175 -0
  67. easy_rec/python/input/csv_input_ex.py +72 -0
  68. easy_rec/python/input/csv_input_v2.py +68 -0
  69. easy_rec/python/input/datahub_input.py +320 -0
  70. easy_rec/python/input/dummy_input.py +58 -0
  71. easy_rec/python/input/hive_input.py +123 -0
  72. easy_rec/python/input/hive_parquet_input.py +140 -0
  73. easy_rec/python/input/hive_rtp_input.py +174 -0
  74. easy_rec/python/input/input.py +1064 -0
  75. easy_rec/python/input/kafka_dataset.py +144 -0
  76. easy_rec/python/input/kafka_input.py +235 -0
  77. easy_rec/python/input/load_parquet.py +317 -0
  78. easy_rec/python/input/odps_input.py +101 -0
  79. easy_rec/python/input/odps_input_v2.py +110 -0
  80. easy_rec/python/input/odps_input_v3.py +132 -0
  81. easy_rec/python/input/odps_rtp_input.py +187 -0
  82. easy_rec/python/input/odps_rtp_input_v2.py +104 -0
  83. easy_rec/python/input/parquet_input.py +397 -0
  84. easy_rec/python/input/parquet_input_v2.py +180 -0
  85. easy_rec/python/input/parquet_input_v3.py +203 -0
  86. easy_rec/python/input/rtp_input.py +225 -0
  87. easy_rec/python/input/rtp_input_v2.py +145 -0
  88. easy_rec/python/input/tfrecord_input.py +100 -0
  89. easy_rec/python/layers/__init__.py +0 -0
  90. easy_rec/python/layers/backbone.py +571 -0
  91. easy_rec/python/layers/capsule_layer.py +176 -0
  92. easy_rec/python/layers/cmbf.py +390 -0
  93. easy_rec/python/layers/common_layers.py +192 -0
  94. easy_rec/python/layers/dnn.py +87 -0
  95. easy_rec/python/layers/embed_input_layer.py +25 -0
  96. easy_rec/python/layers/fm.py +26 -0
  97. easy_rec/python/layers/input_layer.py +396 -0
  98. easy_rec/python/layers/keras/__init__.py +34 -0
  99. easy_rec/python/layers/keras/activation.py +114 -0
  100. easy_rec/python/layers/keras/attention.py +267 -0
  101. easy_rec/python/layers/keras/auxiliary_loss.py +47 -0
  102. easy_rec/python/layers/keras/blocks.py +262 -0
  103. easy_rec/python/layers/keras/bst.py +119 -0
  104. easy_rec/python/layers/keras/custom_ops.py +250 -0
  105. easy_rec/python/layers/keras/data_augment.py +133 -0
  106. easy_rec/python/layers/keras/din.py +67 -0
  107. easy_rec/python/layers/keras/einsum_dense.py +598 -0
  108. easy_rec/python/layers/keras/embedding.py +81 -0
  109. easy_rec/python/layers/keras/fibinet.py +251 -0
  110. easy_rec/python/layers/keras/interaction.py +416 -0
  111. easy_rec/python/layers/keras/layer_norm.py +364 -0
  112. easy_rec/python/layers/keras/mask_net.py +166 -0
  113. easy_rec/python/layers/keras/multi_head_attention.py +717 -0
  114. easy_rec/python/layers/keras/multi_task.py +125 -0
  115. easy_rec/python/layers/keras/numerical_embedding.py +376 -0
  116. easy_rec/python/layers/keras/ppnet.py +194 -0
  117. easy_rec/python/layers/keras/transformer.py +192 -0
  118. easy_rec/python/layers/layer_norm.py +51 -0
  119. easy_rec/python/layers/mmoe.py +83 -0
  120. easy_rec/python/layers/multihead_attention.py +162 -0
  121. easy_rec/python/layers/multihead_cross_attention.py +749 -0
  122. easy_rec/python/layers/senet.py +73 -0
  123. easy_rec/python/layers/seq_input_layer.py +134 -0
  124. easy_rec/python/layers/sequence_feature_layer.py +249 -0
  125. easy_rec/python/layers/uniter.py +301 -0
  126. easy_rec/python/layers/utils.py +248 -0
  127. easy_rec/python/layers/variational_dropout_layer.py +130 -0
  128. easy_rec/python/loss/__init__.py +0 -0
  129. easy_rec/python/loss/circle_loss.py +82 -0
  130. easy_rec/python/loss/contrastive_loss.py +79 -0
  131. easy_rec/python/loss/f1_reweight_loss.py +38 -0
  132. easy_rec/python/loss/focal_loss.py +93 -0
  133. easy_rec/python/loss/jrc_loss.py +128 -0
  134. easy_rec/python/loss/listwise_loss.py +161 -0
  135. easy_rec/python/loss/multi_similarity.py +68 -0
  136. easy_rec/python/loss/pairwise_loss.py +307 -0
  137. easy_rec/python/loss/softmax_loss_with_negative_mining.py +110 -0
  138. easy_rec/python/loss/zero_inflated_lognormal.py +76 -0
  139. easy_rec/python/main.py +878 -0
  140. easy_rec/python/model/__init__.py +0 -0
  141. easy_rec/python/model/autoint.py +73 -0
  142. easy_rec/python/model/cmbf.py +47 -0
  143. easy_rec/python/model/collaborative_metric_learning.py +182 -0
  144. easy_rec/python/model/custom_model.py +323 -0
  145. easy_rec/python/model/dat.py +138 -0
  146. easy_rec/python/model/dbmtl.py +116 -0
  147. easy_rec/python/model/dcn.py +70 -0
  148. easy_rec/python/model/deepfm.py +106 -0
  149. easy_rec/python/model/dlrm.py +73 -0
  150. easy_rec/python/model/dropoutnet.py +207 -0
  151. easy_rec/python/model/dssm.py +154 -0
  152. easy_rec/python/model/dssm_senet.py +143 -0
  153. easy_rec/python/model/dummy_model.py +48 -0
  154. easy_rec/python/model/easy_rec_estimator.py +739 -0
  155. easy_rec/python/model/easy_rec_model.py +467 -0
  156. easy_rec/python/model/esmm.py +242 -0
  157. easy_rec/python/model/fm.py +63 -0
  158. easy_rec/python/model/match_model.py +357 -0
  159. easy_rec/python/model/mind.py +445 -0
  160. easy_rec/python/model/mmoe.py +70 -0
  161. easy_rec/python/model/multi_task_model.py +303 -0
  162. easy_rec/python/model/multi_tower.py +62 -0
  163. easy_rec/python/model/multi_tower_bst.py +190 -0
  164. easy_rec/python/model/multi_tower_din.py +130 -0
  165. easy_rec/python/model/multi_tower_recall.py +68 -0
  166. easy_rec/python/model/pdn.py +203 -0
  167. easy_rec/python/model/ple.py +120 -0
  168. easy_rec/python/model/rank_model.py +485 -0
  169. easy_rec/python/model/rocket_launching.py +203 -0
  170. easy_rec/python/model/simple_multi_task.py +54 -0
  171. easy_rec/python/model/uniter.py +46 -0
  172. easy_rec/python/model/wide_and_deep.py +121 -0
  173. easy_rec/python/ops/1.12/incr_record.so +0 -0
  174. easy_rec/python/ops/1.12/kafka.so +0 -0
  175. easy_rec/python/ops/1.12/libcustom_ops.so +0 -0
  176. easy_rec/python/ops/1.12/libembed_op.so +0 -0
  177. easy_rec/python/ops/1.12/libhiredis.so.1.0.0 +0 -0
  178. easy_rec/python/ops/1.12/librdkafka++.so.1 +0 -0
  179. easy_rec/python/ops/1.12/librdkafka.so.1 +0 -0
  180. easy_rec/python/ops/1.12/libredis++.so +0 -0
  181. easy_rec/python/ops/1.12/libredis++.so.1 +0 -0
  182. easy_rec/python/ops/1.12/libredis++.so.1.2.3 +0 -0
  183. easy_rec/python/ops/1.12/libstr_avx_op.so +0 -0
  184. easy_rec/python/ops/1.12/libwrite_sparse_kv.so +0 -0
  185. easy_rec/python/ops/1.15/incr_record.so +0 -0
  186. easy_rec/python/ops/1.15/kafka.so +0 -0
  187. easy_rec/python/ops/1.15/libcustom_ops.so +0 -0
  188. easy_rec/python/ops/1.15/libembed_op.so +0 -0
  189. easy_rec/python/ops/1.15/libhiredis.so.1.0.0 +0 -0
  190. easy_rec/python/ops/1.15/librdkafka++.so +0 -0
  191. easy_rec/python/ops/1.15/librdkafka++.so.1 +0 -0
  192. easy_rec/python/ops/1.15/librdkafka.so +0 -0
  193. easy_rec/python/ops/1.15/librdkafka.so.1 +0 -0
  194. easy_rec/python/ops/1.15/libredis++.so.1 +0 -0
  195. easy_rec/python/ops/1.15/libstr_avx_op.so +0 -0
  196. easy_rec/python/ops/2.12/libcustom_ops.so +0 -0
  197. easy_rec/python/ops/2.12/libload_embed.so +0 -0
  198. easy_rec/python/ops/2.12/libstr_avx_op.so +0 -0
  199. easy_rec/python/ops/__init__.py +0 -0
  200. easy_rec/python/ops/gen_kafka_ops.py +193 -0
  201. easy_rec/python/ops/gen_str_avx_op.py +28 -0
  202. easy_rec/python/ops/incr_record.py +30 -0
  203. easy_rec/python/predict.py +170 -0
  204. easy_rec/python/protos/__init__.py +0 -0
  205. easy_rec/python/protos/autoint_pb2.py +122 -0
  206. easy_rec/python/protos/backbone_pb2.py +1416 -0
  207. easy_rec/python/protos/cmbf_pb2.py +435 -0
  208. easy_rec/python/protos/collaborative_metric_learning_pb2.py +252 -0
  209. easy_rec/python/protos/custom_model_pb2.py +57 -0
  210. easy_rec/python/protos/dat_pb2.py +262 -0
  211. easy_rec/python/protos/data_source_pb2.py +422 -0
  212. easy_rec/python/protos/dataset_pb2.py +1920 -0
  213. easy_rec/python/protos/dbmtl_pb2.py +191 -0
  214. easy_rec/python/protos/dcn_pb2.py +197 -0
  215. easy_rec/python/protos/deepfm_pb2.py +163 -0
  216. easy_rec/python/protos/dlrm_pb2.py +163 -0
  217. easy_rec/python/protos/dnn_pb2.py +329 -0
  218. easy_rec/python/protos/dropoutnet_pb2.py +239 -0
  219. easy_rec/python/protos/dssm_pb2.py +262 -0
  220. easy_rec/python/protos/dssm_senet_pb2.py +282 -0
  221. easy_rec/python/protos/easy_rec_model_pb2.py +1672 -0
  222. easy_rec/python/protos/esmm_pb2.py +133 -0
  223. easy_rec/python/protos/eval_pb2.py +930 -0
  224. easy_rec/python/protos/export_pb2.py +379 -0
  225. easy_rec/python/protos/feature_config_pb2.py +1359 -0
  226. easy_rec/python/protos/fm_pb2.py +90 -0
  227. easy_rec/python/protos/hive_config_pb2.py +138 -0
  228. easy_rec/python/protos/hyperparams_pb2.py +624 -0
  229. easy_rec/python/protos/keras_layer_pb2.py +692 -0
  230. easy_rec/python/protos/layer_pb2.py +1936 -0
  231. easy_rec/python/protos/loss_pb2.py +1713 -0
  232. easy_rec/python/protos/mind_pb2.py +497 -0
  233. easy_rec/python/protos/mmoe_pb2.py +215 -0
  234. easy_rec/python/protos/multi_tower_pb2.py +295 -0
  235. easy_rec/python/protos/multi_tower_recall_pb2.py +198 -0
  236. easy_rec/python/protos/optimizer_pb2.py +2017 -0
  237. easy_rec/python/protos/pdn_pb2.py +293 -0
  238. easy_rec/python/protos/pipeline_pb2.py +516 -0
  239. easy_rec/python/protos/ple_pb2.py +231 -0
  240. easy_rec/python/protos/predict_pb2.py +1140 -0
  241. easy_rec/python/protos/rocket_launching_pb2.py +169 -0
  242. easy_rec/python/protos/seq_encoder_pb2.py +1084 -0
  243. easy_rec/python/protos/simi_pb2.py +54 -0
  244. easy_rec/python/protos/simple_multi_task_pb2.py +97 -0
  245. easy_rec/python/protos/tf_predict_pb2.py +630 -0
  246. easy_rec/python/protos/tower_pb2.py +661 -0
  247. easy_rec/python/protos/train_pb2.py +1197 -0
  248. easy_rec/python/protos/uniter_pb2.py +307 -0
  249. easy_rec/python/protos/variational_dropout_pb2.py +91 -0
  250. easy_rec/python/protos/wide_and_deep_pb2.py +131 -0
  251. easy_rec/python/test/__init__.py +0 -0
  252. easy_rec/python/test/csv_input_test.py +340 -0
  253. easy_rec/python/test/custom_early_stop_func.py +19 -0
  254. easy_rec/python/test/dh_local_run.py +104 -0
  255. easy_rec/python/test/embed_test.py +155 -0
  256. easy_rec/python/test/emr_run.py +119 -0
  257. easy_rec/python/test/eval_metric_test.py +107 -0
  258. easy_rec/python/test/excel_convert_test.py +64 -0
  259. easy_rec/python/test/export_test.py +513 -0
  260. easy_rec/python/test/fg_test.py +70 -0
  261. easy_rec/python/test/hive_input_test.py +311 -0
  262. easy_rec/python/test/hpo_test.py +235 -0
  263. easy_rec/python/test/kafka_test.py +373 -0
  264. easy_rec/python/test/local_incr_test.py +122 -0
  265. easy_rec/python/test/loss_test.py +110 -0
  266. easy_rec/python/test/odps_command.py +61 -0
  267. easy_rec/python/test/odps_local_run.py +86 -0
  268. easy_rec/python/test/odps_run.py +254 -0
  269. easy_rec/python/test/odps_test_cls.py +39 -0
  270. easy_rec/python/test/odps_test_prepare.py +198 -0
  271. easy_rec/python/test/odps_test_util.py +237 -0
  272. easy_rec/python/test/pre_check_test.py +54 -0
  273. easy_rec/python/test/predictor_test.py +394 -0
  274. easy_rec/python/test/rtp_convert_test.py +133 -0
  275. easy_rec/python/test/run.py +138 -0
  276. easy_rec/python/test/train_eval_test.py +1299 -0
  277. easy_rec/python/test/util_test.py +85 -0
  278. easy_rec/python/test/zero_inflated_lognormal_test.py +53 -0
  279. easy_rec/python/tools/__init__.py +0 -0
  280. easy_rec/python/tools/add_boundaries_to_config.py +67 -0
  281. easy_rec/python/tools/add_feature_info_to_config.py +145 -0
  282. easy_rec/python/tools/convert_config_format.py +48 -0
  283. easy_rec/python/tools/convert_rtp_data.py +79 -0
  284. easy_rec/python/tools/convert_rtp_fg.py +106 -0
  285. easy_rec/python/tools/create_config_from_excel.py +427 -0
  286. easy_rec/python/tools/criteo/__init__.py +0 -0
  287. easy_rec/python/tools/criteo/convert_data.py +157 -0
  288. easy_rec/python/tools/edit_lookup_graph.py +134 -0
  289. easy_rec/python/tools/faiss_index_pai.py +116 -0
  290. easy_rec/python/tools/feature_selection.py +316 -0
  291. easy_rec/python/tools/hit_rate_ds.py +223 -0
  292. easy_rec/python/tools/hit_rate_pai.py +138 -0
  293. easy_rec/python/tools/pre_check.py +120 -0
  294. easy_rec/python/tools/predict_and_chk.py +111 -0
  295. easy_rec/python/tools/read_kafka.py +55 -0
  296. easy_rec/python/tools/split_model_pai.py +286 -0
  297. easy_rec/python/tools/split_pdn_model_pai.py +272 -0
  298. easy_rec/python/tools/test_saved_model.py +80 -0
  299. easy_rec/python/tools/view_saved_model.py +39 -0
  300. easy_rec/python/tools/write_kafka.py +65 -0
  301. easy_rec/python/train_eval.py +325 -0
  302. easy_rec/python/utils/__init__.py +15 -0
  303. easy_rec/python/utils/activation.py +120 -0
  304. easy_rec/python/utils/check_utils.py +87 -0
  305. easy_rec/python/utils/compat.py +14 -0
  306. easy_rec/python/utils/config_util.py +652 -0
  307. easy_rec/python/utils/constant.py +43 -0
  308. easy_rec/python/utils/convert_rtp_fg.py +616 -0
  309. easy_rec/python/utils/dag.py +192 -0
  310. easy_rec/python/utils/distribution_utils.py +268 -0
  311. easy_rec/python/utils/ds_util.py +65 -0
  312. easy_rec/python/utils/embedding_utils.py +73 -0
  313. easy_rec/python/utils/estimator_utils.py +1036 -0
  314. easy_rec/python/utils/export_big_model.py +630 -0
  315. easy_rec/python/utils/expr_util.py +118 -0
  316. easy_rec/python/utils/fg_util.py +53 -0
  317. easy_rec/python/utils/hit_rate_utils.py +220 -0
  318. easy_rec/python/utils/hive_utils.py +183 -0
  319. easy_rec/python/utils/hpo_util.py +137 -0
  320. easy_rec/python/utils/hvd_utils.py +56 -0
  321. easy_rec/python/utils/input_utils.py +108 -0
  322. easy_rec/python/utils/io_util.py +282 -0
  323. easy_rec/python/utils/load_class.py +249 -0
  324. easy_rec/python/utils/meta_graph_editor.py +941 -0
  325. easy_rec/python/utils/multi_optimizer.py +62 -0
  326. easy_rec/python/utils/numpy_utils.py +18 -0
  327. easy_rec/python/utils/odps_util.py +79 -0
  328. easy_rec/python/utils/pai_util.py +86 -0
  329. easy_rec/python/utils/proto_util.py +90 -0
  330. easy_rec/python/utils/restore_filter.py +89 -0
  331. easy_rec/python/utils/shape_utils.py +432 -0
  332. easy_rec/python/utils/static_shape.py +71 -0
  333. easy_rec/python/utils/test_utils.py +866 -0
  334. easy_rec/python/utils/tf_utils.py +56 -0
  335. easy_rec/version.py +4 -0
  336. test/__init__.py +0 -0
easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py
@@ -0,0 +1,3702 @@
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Implementation of tf.metrics module."""
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ from tensorflow.python.eager import context
+ from tensorflow.python.framework import dtypes
+ from tensorflow.python.framework import ops
+ from tensorflow.python.framework import sparse_tensor
+ from tensorflow.python.ops import array_ops
+ from tensorflow.python.ops import check_ops
+ from tensorflow.python.ops import confusion_matrix
+ from tensorflow.python.ops import control_flow_ops
+ from tensorflow.python.ops import math_ops
+ from tensorflow.python.ops import nn
+ from tensorflow.python.ops import sets
+ from tensorflow.python.ops import sparse_ops
+ from tensorflow.python.ops import state_ops
+ from tensorflow.python.ops import variable_scope
+ from tensorflow.python.ops import weights_broadcast_ops
+ from tensorflow.python.platform import tf_logging as logging
+ from tensorflow.python.training import distribution_strategy_context
+ from tensorflow.python.util.deprecation import deprecated
+ from tensorflow.python.util.tf_export import tf_export
+
+
+ def metric_variable(shape, dtype, validate_shape=True, name=None):
+   """Create variable in `GraphKeys.(GLOBAL|METRIC_VARIABLES)` collections.
+
+   If running in a `DistributionStrategy` context, the variable will be
+   "tower local". This means:
+
+   * The returned object will be a container with separate variables
+     per replica/tower of the model.
+
+   * When writing to the variable, e.g. using `assign_add` in a metric
+     update, the update will be applied to the variable local to the
+     replica/tower.
+
+   * To get a metric's result value, we need to sum the variable values
+     across the replicas/towers before computing the final answer.
+     Furthermore, the final answer should be computed once instead of
+     in every replica/tower. Both of these are accomplished by
+     running the computation of the final result value inside
+     `tf.contrib.distribution_strategy_context.get_tower_context(
+     ).merge_call(fn)`.
+     Inside the `merge_call()`, ops are only added to the graph once
+     and access to a tower-local variable in a computation returns
+     the sum across all replicas/towers.
+
+   Args:
+     shape: Shape of the created variable.
+     dtype: Type of the created variable.
+     validate_shape: (Optional) Whether shape validation is enabled for
+       the created variable.
+     name: (Optional) String name of the created variable.
+
+   Returns:
+     A (non-trainable) variable initialized to zero, or if inside a
+     `DistributionStrategy` scope a tower-local variable container.
+   """
+   # Note that synchronization "ON_READ" implies trainable=False.
+   return variable_scope.variable(
+       lambda: array_ops.zeros(shape, dtype),
+       collections=[
+           ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+       ],
+       validate_shape=validate_shape,
+       synchronization=variable_scope.VariableSynchronization.ON_READ,
+       aggregation=variable_scope.VariableAggregation.SUM,
+       name=name)
+
+
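For orientation, everything in this file is built on the value/update-op pattern that starts from `metric_variable`. A minimal sketch of how such a variable accumulates state (assuming a TF 1.x graph-mode session, which matches the TF 1.12/1.15 custom ops this wheel ships; the import path is taken from the file list above):

import tensorflow as tf
from easy_rec.python.core.easyrec_metrics.distribute_metrics_impl_pai import (
    metric_variable)

# Zero-initialized accumulator. Note this vendored copy registers it in
# GLOBAL_VARIABLES (plus METRIC_VARIABLES), unlike stock tf.metrics, which
# uses LOCAL_VARIABLES.
total = metric_variable([], tf.float32, name='demo_total')
update = total.assign_add(2.5)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update)
  sess.run(update)
  print(sess.run(total))  # 5.0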
+ def _remove_squeezable_dimensions(predictions, labels, weights):
+   """Squeeze or expand last dim if needed.
+
+   Squeezes last dim of `predictions` or `labels` if their rank differs by 1
+   (using confusion_matrix.remove_squeezable_dimensions).
+   Squeezes or expands last dim of `weights` if its rank differs by 1 from the
+   new rank of `predictions`.
+
+   If `weights` is scalar, it is kept scalar.
+
+   This will use static shape if available. Otherwise, it will add graph
+   operations, which could result in a performance hit.
+
+   Args:
+     predictions: Predicted values, a `Tensor` of arbitrary dimensions.
+     labels: Optional label `Tensor` whose dimensions match `predictions`.
+     weights: Optional weight scalar or `Tensor` whose dimensions match
+       `predictions`.
+
+   Returns:
+     Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
+     the last dimension squeezed, `weights` could be extended by one dimension.
+   """
+   predictions = ops.convert_to_tensor(predictions)
+   if labels is not None:
+     labels, predictions = confusion_matrix.remove_squeezable_dimensions(
+         labels, predictions)
+     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+   if weights is None:
+     return predictions, labels, None
+
+   weights = ops.convert_to_tensor(weights)
+   weights_shape = weights.get_shape()
+   weights_rank = weights_shape.ndims
+   if weights_rank == 0:
+     return predictions, labels, weights
+
+   predictions_shape = predictions.get_shape()
+   predictions_rank = predictions_shape.ndims
+   if (predictions_rank is not None) and (weights_rank is not None):
+     # Use static rank.
+     if weights_rank - predictions_rank == 1:
+       weights = array_ops.squeeze(weights, [-1])
+     elif predictions_rank - weights_rank == 1:
+       weights = array_ops.expand_dims(weights, [-1])
+   else:
+     # Use dynamic rank.
+     weights_rank_tensor = array_ops.rank(weights)
+     rank_diff = weights_rank_tensor - array_ops.rank(predictions)
+
+     def _maybe_expand_weights():
+       return control_flow_ops.cond(
+           math_ops.equal(rank_diff, -1),
+           lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)
+
+     # Don't attempt squeeze if it will fail based on static check.
+     if ((weights_rank is not None) and
+         (not weights_shape.dims[-1].is_compatible_with(1))):
+       maybe_squeeze_weights = lambda: weights  # noqa: E731
+     else:
+       maybe_squeeze_weights = lambda: array_ops.squeeze(  # noqa: E731
+           weights, [-1])  # noqa: E126
+
+     def _maybe_adjust_weights():
+       return control_flow_ops.cond(
+           math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
+           _maybe_expand_weights)
+
+     # If weights are scalar, do nothing. Otherwise, try to add or remove a
+     # dimension to match predictions.
+     weights = control_flow_ops.cond(
+         math_ops.equal(weights_rank_tensor, 0), lambda: weights,
+         _maybe_adjust_weights)
+   return predictions, labels, weights
+
+
+ def _maybe_expand_labels(labels, predictions):
+   """If necessary, expand `labels` along last dimension to match `predictions`.
+
+   Args:
+     labels: `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
+       num_labels=1, in which case the result is an expanded `labels` with shape
+       [D1, ... DN, 1].
+     predictions: `Tensor` with shape [D1, ... DN, num_classes].
+
+   Returns:
+     `labels` with the same rank as `predictions`.
+
+   Raises:
+     ValueError: if `labels` has invalid shape.
+   """
+   with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
+     labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
+
+     # If sparse, expand sparse shape.
+     if isinstance(labels, sparse_tensor.SparseTensor):
+       return control_flow_ops.cond(
+           math_ops.equal(
+               array_ops.rank(predictions),
+               array_ops.size(labels.dense_shape) + 1),
+           lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
+               labels,
+               shape=array_ops.concat((labels.dense_shape, (1,)), 0),
+               name=scope),
+           lambda: labels)
+
+     # Otherwise, try to use static shape.
+     labels_rank = labels.get_shape().ndims
+     if labels_rank is not None:
+       predictions_rank = predictions.get_shape().ndims
+       if predictions_rank is not None:
+         if predictions_rank == labels_rank:
+           return labels
+         if predictions_rank == labels_rank + 1:
+           return array_ops.expand_dims(labels, -1, name=scope)
+         raise ValueError(
+             'Unexpected labels shape %s for predictions shape %s.' %
+             (labels.get_shape(), predictions.get_shape()))
+
+     # Otherwise, use dynamic shape.
+     return control_flow_ops.cond(
+         math_ops.equal(array_ops.rank(predictions),
+                        array_ops.rank(labels) + 1),
+         lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
+
+
+ def _safe_div(numerator, denominator, name):
+   """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
+
+   Args:
+     numerator: A real `Tensor`.
+     denominator: A real `Tensor`, with dtype matching `numerator`.
+     name: Name for the returned op.
+
+   Returns:
+     0 if `denominator` <= 0, else `numerator` / `denominator`
+   """
+   t = math_ops.truediv(numerator, denominator)
+   zero = array_ops.zeros_like(t, dtype=denominator.dtype)
+   condition = math_ops.greater(denominator, zero)
+   zero = math_ops.cast(zero, t.dtype)
+   return array_ops.where(condition, t, zero, name=name)
+
+
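A quick sanity check of the zero-guard above (a sketch under the same TF 1.x assumptions; `_safe_div` is a private helper, so calling it directly is purely illustrative):

import tensorflow as tf
from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_pai as m

num = tf.constant([1.0, 2.0, 3.0])
den = tf.constant([2.0, 0.0, -4.0])
safe = m._safe_div(num, den, name='safe')  # yields 0 wherever denominator <= 0

with tf.Session() as sess:
  print(sess.run(safe))  # [0.5, 0.0, 0.0]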
+ def _safe_scalar_div(numerator, denominator, name):
+   """Divides two values, returning 0 if the denominator is 0.
+
+   Args:
+     numerator: A scalar `float64` `Tensor`.
+     denominator: A scalar `float64` `Tensor`.
+     name: Name for the returned op.
+
+   Returns:
+     0 if `denominator` == 0, else `numerator` / `denominator`
+   """
+   numerator.get_shape().with_rank_at_most(1)
+   denominator.get_shape().with_rank_at_most(1)
+   return control_flow_ops.cond(
+       math_ops.equal(
+           array_ops.constant(0.0, dtype=dtypes.float64), denominator),
+       lambda: array_ops.constant(0.0, dtype=dtypes.float64),
+       lambda: math_ops.div(numerator, denominator),
+       name=name)
+
+
+ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
+   """Calculate a streaming confusion matrix.
+
+   Calculates a confusion matrix. For estimation over a stream of data,
+   the function creates an `update_op` operation.
+
+   Args:
+     labels: A `Tensor` of ground truth labels with shape [batch size] and of
+       type `int32` or `int64`. The tensor will be flattened if its rank > 1.
+     predictions: A `Tensor` of prediction results for semantic labels, whose
+       shape is [batch size] and type `int32` or `int64`. The tensor will be
+       flattened if its rank > 1.
+     num_classes: The possible number of labels the prediction task can
+       have. This value must be provided, since a confusion matrix of
+       dimension = [num_classes, num_classes] will be allocated.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+
+   Returns:
+     total_cm: A `Tensor` representing the confusion matrix.
+     update_op: An operation that increments the confusion matrix.
+   """
+   # Local variable to accumulate the predictions in the confusion matrix.
+   total_cm = metric_variable([num_classes, num_classes],
+                              dtypes.float64,
+                              name='total_confusion_matrix')
+
+   # Cast the type to int64 required by confusion_matrix_ops.
+   predictions = math_ops.to_int64(predictions)
+   labels = math_ops.to_int64(labels)
+   num_classes = math_ops.to_int64(num_classes)
+
+   # Flatten the input if its rank > 1.
+   if predictions.get_shape().ndims > 1:
+     predictions = array_ops.reshape(predictions, [-1])
+
+   if labels.get_shape().ndims > 1:
+     labels = array_ops.reshape(labels, [-1])
+
+   if (weights is not None) and (weights.get_shape().ndims > 1):
+     weights = array_ops.reshape(weights, [-1])
+
+   # Accumulate the prediction to current confusion matrix.
+   current_cm = confusion_matrix.confusion_matrix(
+       labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
+   update_op = state_ops.assign_add(total_cm, current_cm, use_locking=True)
+   return total_cm, update_op
+
+
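For example, accumulating a 3-class confusion matrix over two mini-batches (a sketch under the same assumptions; `_streaming_confusion_matrix` is private and is called directly here only to illustrate the value/update pair it returns):

import tensorflow as tf
from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_pai as m

labels = tf.placeholder(tf.int64, [None])
preds = tf.placeholder(tf.int64, [None])
total_cm, update_op = m._streaming_confusion_matrix(labels, preds, num_classes=3)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op, {labels: [0, 1, 2], preds: [0, 2, 2]})
  sess.run(update_op, {labels: [0, 0], preds: [1, 0]})
  print(sess.run(total_cm))  # rows are labels, columns are predictions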
+ def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
+   """Aggregate metric value across towers."""
+
+   def fn(distribution, *a):
+     """Call `metric_value_fn` in the correct control flow context."""
+     if hasattr(distribution, '_outer_control_flow_context'):
+       # If there was an outer context captured before this method was called,
+       # then we enter that context to create the metric value op. If the
+       # captured context is `None`, ops.control_dependencies(None) gives the
+       # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
+       # captured context.
+       # This special handling is needed because sometimes the metric is created
+       # inside a while_loop (and perhaps a TPU rewrite context). But we don't
+       # want the value op to be evaluated every step or on the TPU. So we
+       # create it outside so that it can be evaluated at the end on the host,
+       # once the update ops have been evaluated.
+
+       # pylint: disable=protected-access
+       if distribution._outer_control_flow_context is None:
+         with ops.control_dependencies(None):
+           metric_value = metric_value_fn(distribution, *a)
+       else:
+         distribution._outer_control_flow_context.Enter()
+         metric_value = metric_value_fn(distribution, *a)
+         distribution._outer_control_flow_context.Exit()
+       # pylint: enable=protected-access
+     else:
+       metric_value = metric_value_fn(distribution, *a)
+     if metrics_collections:
+       ops.add_to_collections(metrics_collections, metric_value)
+     return metric_value
+
+   return distribution_strategy_context.get_tower_context().merge_call(fn, *args)
+
+
+ @tf_export('metrics.mean')
+ def mean(values,
+          weights=None,
+          metrics_collections=None,
+          updates_collections=None,
+          name=None):
+   """Computes the (weighted) mean of the given values.
+
+   The `mean` function creates two local variables, `total` and `count`
+   that are used to compute the average of `values`. This average is ultimately
+   returned as `mean` which is an idempotent operation that simply divides
+   `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `mean`.
+   `update_op` increments `total` with the reduced sum of the product of `values`
+   and `weights`, and it increments `count` with the reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A `Tensor` of arbitrary dimensions.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that `mean`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op`
+       should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean: A `Tensor` representing the current mean, the value of `total` divided
+       by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_value`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean is not supported when eager execution '
+                        'is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean', (values, weights)):
+     values = math_ops.to_float(values)
+
+     total = metric_variable([], dtypes.float32, name='total')
+     count = metric_variable([], dtypes.float32, name='count')
+
+     if weights is None:
+       num_values = math_ops.to_float(array_ops.size(values))
+     else:
+       values, _, weights = _remove_squeezable_dimensions(
+           predictions=values, labels=None, weights=weights)
+       weights = weights_broadcast_ops.broadcast_weights(
+           math_ops.to_float(weights), values)
+       values = math_ops.multiply(values, weights)
+       num_values = math_ops.reduce_sum(weights)
+
+     update_total_op = state_ops.assign_add(
+         total, math_ops.reduce_sum(values), use_locking=True)
+     with ops.control_dependencies([values]):
+       update_count_op = state_ops.assign_add(
+           count, num_values, use_locking=True)
+
+     compute_mean = lambda _, t, c: _safe_div(t, c, 'value')  # noqa: E731
+
+     mean_t = _aggregate_across_towers(metrics_collections, compute_mean, total,
+                                       count)
+     update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_t, update_op
+
+
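The streaming idiom is the same for every public metric in this file: run `update_op` once per batch, then read the value tensor at the end. A sketch with `mean` (same TF 1.x assumptions as above):

import tensorflow as tf
from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_pai as m

values = tf.placeholder(tf.float32, [None])
mean_t, update_op = m.mean(values)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op, {values: [1.0, 2.0]})  # total=3, count=2
  sess.run(update_op, {values: [4.0]})       # total=7, count=3
  print(sess.run(mean_t))                    # 7/3 ~= 2.333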
+ @tf_export('metrics.accuracy')
+ def accuracy(labels,
+              predictions,
+              weights=None,
+              metrics_collections=None,
+              updates_collections=None,
+              name=None):
+   """Calculates how often `predictions` matches `labels`.
+
+   The `accuracy` function creates two local variables, `total` and
+   `count` that are used to compute the frequency with which `predictions`
+   matches `labels`. This frequency is ultimately returned as `accuracy`: an
+   idempotent operation that simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `accuracy`.
+   Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
+   where the corresponding elements of `predictions` and `labels` match and 0.0
+   otherwise. Then `update_op` increments `total` with the reduced sum of the
+   product of `weights` and `is_correct`, and it increments `count` with the
+   reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose shape matches
+       `predictions`.
+     predictions: The predicted values, a `Tensor` of any shape.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `accuracy` should
+       be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     accuracy: A `Tensor` representing the accuracy, the value of `total` divided
+       by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `accuracy`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.accuracy is not supported when eager '
+                        'execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+   if labels.dtype != predictions.dtype:
+     predictions = math_ops.cast(predictions, labels.dtype)
+   is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
+   return mean(is_correct, weights, metrics_collections, updates_collections,
+               name or 'accuracy')
+
+
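Because `accuracy` reduces to `mean(is_correct, weights, ...)`, a zero weight removes an example from both the numerator and the denominator. Sketch:

import tensorflow as tf
from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_pai as m

acc, update_op = m.accuracy(
    labels=tf.constant([1, 0, 1, 1]),
    predictions=tf.constant([1, 1, 1, 0]),
    weights=tf.constant([1.0, 1.0, 1.0, 0.0]))  # mask out the last example

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op)
  print(sess.run(acc))  # 2 correct of 3 weighted examples -> ~0.667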
+ def _confusion_matrix_at_thresholds(labels,
+                                     predictions,
+                                     thresholds,
+                                     weights=None,
+                                     includes=None):
+   """Computes true_positives, false_negatives, true_negatives, false_positives.
+
+   This function creates up to four local variables, `true_positives`,
+   `true_negatives`, `false_positives` and `false_negatives`.
+   `true_positive[i]` is defined as the total weight of values in `predictions`
+   above `thresholds[i]` whose corresponding entry in `labels` is `True`.
+   `false_negatives[i]` is defined as the total weight of values in `predictions`
+   at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
+   `true_negatives[i]` is defined as the total weight of values in `predictions`
+   at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
+   `false_positives[i]` is defined as the total weight of values in `predictions`
+   above `thresholds[i]` whose corresponding entry in `labels` is `False`.
+
+   For estimation of these metrics over a stream of data, for each metric the
+   function respectively creates an `update_op` operation that updates the
+   variable and returns its value.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
+       default to all four.
+
+   Returns:
+     values: Dict of variables of shape `[len(thresholds)]`. Keys are from
+       `includes`.
+     update_ops: Dict of operations that increments the `values`. Keys are from
+       `includes`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       `includes` contains invalid keys.
+   """
+   all_includes = ('tp', 'fn', 'tn', 'fp')
+   if includes is None:
+     includes = all_includes
+   else:
+     for include in includes:
+       if include not in all_includes:
+         raise ValueError('Invalid key: %s.' % include)
+
+   with ops.control_dependencies([
+       check_ops.assert_greater_equal(
+           predictions,
+           math_ops.cast(0.0, dtype=predictions.dtype),
+           message='predictions must be in [0, 1]'),
+       check_ops.assert_less_equal(
+           predictions,
+           math_ops.cast(1.0, dtype=predictions.dtype),
+           message='predictions must be in [0, 1]')
+   ]):
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.to_float(predictions),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+
+   num_thresholds = len(thresholds)
+
+   # Reshape predictions and labels.
+   predictions_2d = array_ops.reshape(predictions, [-1, 1])
+   labels_2d = array_ops.reshape(
+       math_ops.cast(labels, dtype=dtypes.bool), [1, -1])
+
+   # Use static shape if known.
+   num_predictions = predictions_2d.get_shape().as_list()[0]
+
+   # Otherwise use dynamic shape.
+   if num_predictions is None:
+     num_predictions = array_ops.shape(predictions_2d)[0]
+   thresh_tiled = array_ops.tile(
+       array_ops.expand_dims(array_ops.constant(thresholds), [1]),
+       array_ops.stack([1, num_predictions]))
+
+   # Tile the predictions after thresholding them across different thresholds.
+   pred_is_pos = math_ops.greater(
+       array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
+       thresh_tiled)
+   if ('fn' in includes) or ('tn' in includes):
+     pred_is_neg = math_ops.logical_not(pred_is_pos)
+
+   # Tile labels by number of thresholds
+   label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
+   if ('fp' in includes) or ('tn' in includes):
+     label_is_neg = math_ops.logical_not(label_is_pos)
+
+   if weights is not None:
+     weights = weights_broadcast_ops.broadcast_weights(
+         math_ops.to_float(weights), predictions)
+     weights_tiled = array_ops.tile(
+         array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
+     thresh_tiled.get_shape().assert_is_compatible_with(
+         weights_tiled.get_shape())
+   else:
+     weights_tiled = None
+
+   values = {}
+   update_ops = {}
+
+   if 'tp' in includes:
+     true_p = metric_variable([num_thresholds],
+                              dtypes.float32,
+                              name='true_positives')
+     is_true_positive = math_ops.cast(
+         math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
+     if weights_tiled is not None:
+       is_true_positive *= weights_tiled
+     update_ops['tp'] = state_ops.assign_add(
+         true_p, math_ops.reduce_sum(is_true_positive, 1), use_locking=True)
+     values['tp'] = true_p
+
+   if 'fn' in includes:
+     false_n = metric_variable([num_thresholds],
+                               dtypes.float32,
+                               name='false_negatives')
+     is_false_negative = math_ops.cast(
+         math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
+     if weights_tiled is not None:
+       is_false_negative *= weights_tiled
+     update_ops['fn'] = state_ops.assign_add(
+         false_n, math_ops.reduce_sum(is_false_negative, 1), use_locking=True)
+     values['fn'] = false_n
+
+   if 'tn' in includes:
+     true_n = metric_variable([num_thresholds],
+                              dtypes.float32,
+                              name='true_negatives')
+     is_true_negative = math_ops.cast(
+         math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
+     if weights_tiled is not None:
+       is_true_negative *= weights_tiled
+     update_ops['tn'] = state_ops.assign_add(
+         true_n, math_ops.reduce_sum(is_true_negative, 1), use_locking=True)
+     values['tn'] = true_n
+
+   if 'fp' in includes:
+     false_p = metric_variable([num_thresholds],
+                               dtypes.float32,
+                               name='false_positives')
+     is_false_positive = math_ops.cast(
+         math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
+     if weights_tiled is not None:
+       is_false_positive *= weights_tiled
+     update_ops['fp'] = state_ops.assign_add(
+         false_p, math_ops.reduce_sum(is_false_positive, 1), use_locking=True)
+     values['fp'] = false_p
+
+   return values, update_ops
+
+
+ def _aggregate_variable(v, collections):
+   f = lambda distribution, value: distribution.read_var(value)  # noqa: E731
+   return _aggregate_across_towers(collections, f, v)
+
+
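These per-threshold counts are what `auc` below consumes. Intuitively, a positive example whose score falls below a given threshold is counted as a false negative at that threshold. A sketch with two thresholds (private helper, called directly for illustration only):

import tensorflow as tf
from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_pai as m

values, update_ops = m._confusion_matrix_at_thresholds(
    labels=tf.constant([True, False, True]),
    predictions=tf.constant([0.9, 0.6, 0.3]),
    thresholds=[0.25, 0.75])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_ops)
  print(sess.run(values))
  # tp: [2., 1.], fn: [0., 1.], tn: [0., 1.], fp: [1., 0.]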
+ @tf_export('metrics.auc')
+ def auc(labels,
+         predictions,
+         weights=None,
+         num_thresholds=200,
+         metrics_collections=None,
+         updates_collections=None,
+         curve='ROC',
+         name=None,
+         summation_method='trapezoidal'):
+   """Computes the approximate AUC via a Riemann sum.
+
+   The `auc` function creates four local variables, `true_positives`,
+   `true_negatives`, `false_positives` and `false_negatives` that are used to
+   compute the AUC. To discretize the AUC curve, a linearly spaced set of
+   thresholds is used to compute pairs of recall and precision values. The area
+   under the ROC-curve is therefore computed using the height of the recall
+   values by the false positive rate, while the area under the PR-curve is
+   computed using the height of the precision values by the recall.
+
+   This value is ultimately returned as `auc`, an idempotent operation that
+   computes the area under a discretized curve of precision versus recall values
+   (computed using the aforementioned variables). The `num_thresholds` variable
+   controls the degree of discretization with larger numbers of thresholds more
+   closely approximating the true AUC. The quality of the approximation may vary
+   dramatically depending on `num_thresholds`.
+
+   For best results, `predictions` should be distributed approximately uniformly
+   in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
+   approximation may be poor if this is not the case. Setting `summation_method`
+   to 'minoring' or 'majoring' can help quantify the error in the approximation
+   by providing lower or upper bound estimates of the AUC.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `auc`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     num_thresholds: The number of thresholds to use when discretizing the roc
+       curve.
+     metrics_collections: An optional list of collections that `auc` should be
+       added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     curve: Specifies the name of the curve to be computed, 'ROC' [default] or
+       'PR' for the Precision-Recall-curve.
+     name: An optional variable_scope name.
+     summation_method: Specifies the Riemann summation method used
+       (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
+       applies the trapezoidal rule; 'careful_interpolation', a variant of it
+       differing only by a more correct interpolation scheme for PR-AUC -
+       interpolating (true/false) positives but not the ratio that is precision;
+       'minoring' that applies left summation for increasing intervals and right
+       summation for decreasing intervals; 'majoring' that does the opposite.
+       Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
+       (to be deprecated soon) as it applies the same method for ROC, and a
+       better one (see Davis & Goadrich 2006 for details) for the PR curve.
+
+   Returns:
+     auc: A scalar `Tensor` representing the current area-under-curve.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables
+       appropriately and whose value matches `auc`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   print('use_distribute_pai_auc')
+   logging.info('use_distribute_pai_auc')
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.auc is not supported when eager execution '
+                        'is enabled.')
+
+   with variable_scope.variable_scope(name, 'auc',
+                                      (labels, predictions, weights)):
+     if curve != 'ROC' and curve != 'PR':
+       raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
+     kepsilon = 1e-7  # to account for floating point imprecisions
+     thresholds = [
+         (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+     ]
+     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
+
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights)
+
+     # Add epsilons to avoid dividing by 0.
+     epsilon = 1.0e-6
+
+     def interpolate_pr_auc(tp, fp, fn):
+       """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
+
+       Note here we derive & use a closed formula not present in the paper
+       - as follows:
+       Modeling all of TP (true positive weight),
+       FP (false positive weight) and their sum P = TP + FP (positive weight)
+       as varying linearly within each interval [A, B] between successive
+       thresholds, we get
+         Precision = (TP_A + slope * (P - P_A)) / P
+       with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
+       The area within the interval is thus (slope / total_pos_weight) times
+         int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+         int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+       where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
+         int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+       Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
+         slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
+       where dTP == TP_B - TP_A.
+       Note that when P_A == 0 the above calculation simplifies into
+         int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+       which is really equivalent to imputing constant precision throughout the
+       first bucket having >0 true positives.
+
+       Args:
+         tp: true positive counts
+         fp: false positive counts
+         fn: false negative counts
+       Returns:
+         pr_auc: an approximation of the area under the P-R curve.
+       """
+       dtp = tp[:num_thresholds - 1] - tp[1:]
+       p = tp + fp
+       prec_slope = _safe_div(dtp, p[:num_thresholds - 1] - p[1:], 'prec_slope')
+       intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
+       safe_p_ratio = array_ops.where(
+           math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
+           _safe_div(p[:num_thresholds - 1], p[1:], 'recall_relative_ratio'),
+           array_ops.ones_like(p[1:]))
+       return math_ops.reduce_sum(
+           _safe_div(
+               prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
+               tp[1:] + fn[1:],
+               name='pr_auc_increment'),
+           name='interpolate_pr_auc')
+
+     def compute_auc(tp, fn, tn, fp, name):
+       """Computes the roc-auc or pr-auc based on confusion counts."""
+       if curve == 'PR':
+         if summation_method == 'trapezoidal':
+           logging.warning(
+               'Trapezoidal rule is known to produce incorrect PR-AUCs; '
+               'please switch to "careful_interpolation" instead.')
+         elif summation_method == 'careful_interpolation':
+           # This one is a bit tricky and is handled separately.
+           return interpolate_pr_auc(tp, fp, fn)
+       rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
+       if curve == 'ROC':
+         fp_rate = math_ops.div(fp, fp + tn + epsilon)
+         x = fp_rate
+         y = rec
+       else:  # curve == 'PR'.
+         prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
+         x = rec
+         y = prec
+       if summation_method in ('trapezoidal', 'careful_interpolation'):
+         # Note that the case ('PR', 'careful_interpolation') has been handled
+         # above.
+         return math_ops.reduce_sum(
+             math_ops.multiply(x[:num_thresholds - 1] - x[1:],
+                               (y[:num_thresholds - 1] + y[1:]) / 2.),
+             name=name)
+       elif summation_method == 'minoring':
+         return math_ops.reduce_sum(
+             math_ops.multiply(x[:num_thresholds - 1] - x[1:],
+                               math_ops.minimum(y[:num_thresholds - 1], y[1:])),
+             name=name)
+       elif summation_method == 'majoring':
+         return math_ops.reduce_sum(
+             math_ops.multiply(x[:num_thresholds - 1] - x[1:],
+                               math_ops.maximum(y[:num_thresholds - 1], y[1:])),
+             name=name)
+       else:
+         raise ValueError('Invalid summation_method: %s' % summation_method)
+
+     # sum up the areas of all the trapeziums
+     def compute_auc_value(_, values):
+       return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
+                          'value')
+
+     auc_value = _aggregate_across_towers(metrics_collections, compute_auc_value,
+                                          values)
+     update_op = compute_auc(update_ops['tp'], update_ops['fn'],
+                             update_ops['tn'], update_ops['fp'], 'update_op')
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return auc_value, update_op
+
+
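A sketch of the PR-AUC path, which routes through `interpolate_pr_auc` when `summation_method='careful_interpolation'` (same TF 1.x assumptions; note the `use_distribute_pai_auc` print above fires once at graph-construction time):

import tensorflow as tf
from easy_rec.python.core.easyrec_metrics import distribute_metrics_impl_pai as m

labels = tf.placeholder(tf.bool, [None])
scores = tf.placeholder(tf.float32, [None])  # must lie in [0, 1]
auc_t, auc_update = m.auc(labels, scores, num_thresholds=200, curve='PR',
                          summation_method='careful_interpolation')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(auc_update, {labels: [True, False, True, False],
                        scores: [0.9, 0.8, 0.4, 0.1]})
  print(sess.run(auc_t))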
+ @tf_export('metrics.mean_absolute_error')
+ def mean_absolute_error(labels,
+                         predictions,
+                         weights=None,
+                         metrics_collections=None,
+                         updates_collections=None,
+                         name=None):
+   """Computes the mean absolute error between the labels and predictions.
+
+   The `mean_absolute_error` function creates two local variables,
+   `total` and `count` that are used to compute the mean absolute error. This
+   average is weighted by `weights`, and it is ultimately returned as
+   `mean_absolute_error`: an idempotent operation that simply divides `total` by
+   `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
+   absolute value of the differences between `predictions` and `labels`. Then
+   `update_op` increments `total` with the reduced sum of the product of
+   `weights` and `absolute_errors`, and it increments `count` with the reduced
+   sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of the same shape as `predictions`.
+     predictions: A `Tensor` of arbitrary shape.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_absolute_error` should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_absolute_error: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_absolute_error`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
+                        'when eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   absolute_errors = math_ops.abs(predictions - labels)
+   return mean(absolute_errors, weights, metrics_collections,
+               updates_collections, name or 'mean_absolute_error')
+
+
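A minimal usage sketch of the value/update-op pair returned by `mean_absolute_error` (illustrative; the constants and session setup are assumptions, not part of this module):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([0.0, 1.0, 2.0, 3.0])
predictions = tf.constant([0.5, 1.0, 2.5, 3.0])
mae, update_op = tf.metrics.mean_absolute_error(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # `total` and `count` are local
  sess.run(update_op)                         # accumulate one batch
  print(sess.run(mae))                        # (0.5 + 0 + 0.5 + 0) / 4 = 0.25
```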
+ @tf_export('metrics.mean_cosine_distance')
+ def mean_cosine_distance(labels,
+                          predictions,
+                          dim,
+                          weights=None,
+                          metrics_collections=None,
+                          updates_collections=None,
+                          name=None):
+   """Computes the cosine distance between the labels and predictions.
+
+   The `mean_cosine_distance` function creates two local variables,
+   `total` and `count` that are used to compute the average cosine distance
+   between `predictions` and `labels`. This average is weighted by `weights`,
+   and it is ultimately returned as `mean_distance`, which is an idempotent
+   operation that simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_distance`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of arbitrary shape.
+     predictions: A `Tensor` of the same shape as `labels`.
+     dim: The dimension along which the cosine distance is computed.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension). Also,
+       dimension `dim` must be `1`.
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_distance: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
+                        'eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   radial_diffs = math_ops.multiply(predictions, labels)
+   radial_diffs = math_ops.reduce_sum(
+       radial_diffs, reduction_indices=[
+           dim,
+       ], keepdims=True)
+   mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
+                                   'mean_cosine_distance')
+   mean_distance = math_ops.subtract(1.0, mean_distance)
+   update_op = math_ops.subtract(1.0, update_op)
+
+   if metrics_collections:
+     ops.add_to_collections(metrics_collections, mean_distance)
+
+   if updates_collections:
+     ops.add_to_collections(updates_collections, update_op)
+
+   return mean_distance, update_op
+
+
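A small sketch of `mean_cosine_distance` on unit vectors (illustrative; assumes stock `tf.metrics` behavior and unit-normalized inputs along `dim`):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# The metric assumes the inputs are unit-normalized along `dim`.
labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
predictions = tf.constant([[1.0, 0.0], [1.0, 0.0]])
dist, update_op = tf.metrics.mean_cosine_distance(labels, predictions, dim=1)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(dist))  # 1 - mean(1.0, 0.0) = 0.5
```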
+ @tf_export('metrics.mean_per_class_accuracy')
+ def mean_per_class_accuracy(labels,
+                             predictions,
+                             num_classes,
+                             weights=None,
+                             metrics_collections=None,
+                             updates_collections=None,
+                             name=None):
+   """Calculates the mean of the per-class accuracies.
+
+   Calculates the accuracy for each class, then takes the mean of that.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates the accuracy of each class and returns
+   them.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of ground truth labels with shape [batch size] and of
+       type `int32` or `int64`. The tensor will be flattened if its rank > 1.
+     predictions: A `Tensor` of prediction results for semantic labels, whose
+       shape is [batch size] and type `int32` or `int64`. The tensor will be
+       flattened if its rank > 1.
+     num_classes: The possible number of labels the prediction task can
+       have. This value must be provided, since two variables with shape =
+       [num_classes] will be allocated.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_per_class_accuracy` should be added to.
+     updates_collections: An optional list of collections `update_op` should be
+       added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_accuracy: A `Tensor` representing the mean per class accuracy.
+     update_op: An operation that updates the accuracy tensor.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
+                        'when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean_accuracy',
+                                      (predictions, labels, weights)):
+     labels = math_ops.to_int64(labels)
+
+     # Flatten the input if its rank > 1.
+     if labels.get_shape().ndims > 1:
+       labels = array_ops.reshape(labels, [-1])
+
+     if predictions.get_shape().ndims > 1:
+       predictions = array_ops.reshape(predictions, [-1])
+
+     # Check if shape is compatible.
+     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+     total = metric_variable([num_classes], dtypes.float32, name='total')
+     count = metric_variable([num_classes], dtypes.float32, name='count')
+
+     ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)
+
+     if labels.dtype != predictions.dtype:
+       predictions = math_ops.cast(predictions, labels.dtype)
+     is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
+
+     if weights is not None:
+       if weights.get_shape().ndims > 1:
+         weights = array_ops.reshape(weights, [-1])
+       weights = math_ops.to_float(weights)
+
+       is_correct *= weights
+       ones *= weights
+
+     update_total_op = state_ops.scatter_add(
+         total, labels, ones, use_locking=True)
+     update_count_op = state_ops.scatter_add(
+         count, labels, is_correct, use_locking=True)
+
+     def compute_mean_accuracy(_, count, total):
+       per_class_accuracy = _safe_div(count, total, None)
+       mean_accuracy_v = math_ops.reduce_mean(
+           per_class_accuracy, name='mean_accuracy')
+       return mean_accuracy_v
+
+     mean_accuracy_v = _aggregate_across_towers(metrics_collections,
+                                                compute_mean_accuracy, count,
+                                                total)
+
+     update_op = _safe_div(update_count_op, update_total_op, name='update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_accuracy_v, update_op
+
+
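A hand-checkable sketch of `mean_per_class_accuracy` (illustrative data; assumes a TF 1.x graph session):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([0, 0, 1, 2])
predictions = tf.constant([0, 1, 1, 2])
acc, update_op = tf.metrics.mean_per_class_accuracy(
    labels, predictions, num_classes=3)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(acc))  # mean(1/2, 1/1, 1/1) ~= 0.8333
```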
+ @tf_export('metrics.mean_iou')
+ def mean_iou(labels,
+              predictions,
+              num_classes,
+              weights=None,
+              metrics_collections=None,
+              updates_collections=None,
+              name=None):
+   """Calculate per-step mean Intersection-Over-Union (mIOU).
+
+   Mean Intersection-Over-Union is a common evaluation metric for
+   semantic image segmentation, which first computes the IOU for each
+   semantic class and then computes the average over classes.
+   IOU is defined as follows:
+     IOU = true_positive / (true_positive + false_positive + false_negative).
+   The predictions are accumulated in a confusion matrix, weighted by `weights`,
+   and mIOU is then calculated from it.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `mean_iou`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of ground truth labels with shape [batch size] and of
+       type `int32` or `int64`. The tensor will be flattened if its rank > 1.
+     predictions: A `Tensor` of prediction results for semantic labels, whose
+       shape is [batch size] and type `int32` or `int64`. The tensor will be
+       flattened if its rank > 1.
+     num_classes: The possible number of labels the prediction task can
+       have. This value must be provided, since a confusion matrix of
+       dimension = [num_classes, num_classes] will be allocated.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `mean_iou`
+       should be added to.
+     updates_collections: An optional list of collections `update_op` should be
+       added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_iou: A `Tensor` representing the mean intersection-over-union.
+     update_op: An operation that increments the confusion matrix.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_iou is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean_iou',
+                                      (predictions, labels, weights)):
+     # Check if shape is compatible.
+     predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+     total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
+                                                       num_classes, weights)
+
+     def compute_mean_iou(_, total_cm):
+       """Compute the mean intersection-over-union via the confusion matrix."""
+       sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
+       sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
+       cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
+       denominator = sum_over_row + sum_over_col - cm_diag
+
+       # The mean is only computed over classes that appear in the
+       # label or prediction tensor. If the denominator is 0, we need to
+       # ignore the class.
+       num_valid_entries = math_ops.reduce_sum(
+           math_ops.cast(
+               math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
+
+       # If the value of the denominator is 0, set it to 1 to avoid
+       # zero division.
+       denominator = array_ops.where(
+           math_ops.greater(denominator, 0), denominator,
+           array_ops.ones_like(denominator))
+       iou = math_ops.div(cm_diag, denominator)
+
+       # If the number of valid entries is 0 (no classes) we return 0.
+       result = array_ops.where(
+           math_ops.greater(num_valid_entries, 0),
+           math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
+       return result
+
+     # TODO(priyag): Use outside_compilation if in TPU context.
+     mean_iou_v = _aggregate_across_towers(metrics_collections, compute_mean_iou,
+                                           total_cm)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_iou_v, update_op
+
+
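A small worked sketch of `mean_iou` whose confusion matrix can be checked by hand (illustrative; assumes a TF 1.x graph session):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([0, 1, 1, 2])
predictions = tf.constant([0, 1, 2, 2])
miou, update_op = tf.metrics.mean_iou(labels, predictions, num_classes=3)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)    # accumulates the 3x3 confusion matrix
  print(sess.run(miou))  # mean of per-class IOUs (1, 1/2, 1/2) ~= 0.6667
```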
+ @tf_export('metrics.mean_relative_error')
+ def mean_relative_error(labels,
+                         predictions,
+                         normalizer,
+                         weights=None,
+                         metrics_collections=None,
+                         updates_collections=None,
+                         name=None):
+   """Computes the mean relative error by normalizing with the given values.
+
+   The `mean_relative_error` function creates two local variables,
+   `total` and `count` that are used to compute the mean relative absolute error.
+   This average is weighted by `weights`, and it is ultimately returned as
+   `mean_relative_error`: an idempotent operation that simply divides `total` by
+   `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_relative_error`. Internally, a `relative_errors` operation divides the
+   absolute value of the differences between `predictions` and `labels` by the
+   `normalizer`. Then `update_op` increments `total` with the reduced sum of the
+   product of `weights` and `relative_errors`, and it increments `count` with the
+   reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of the same shape as `predictions`.
+     predictions: A `Tensor` of arbitrary shape.
+     normalizer: A `Tensor` of the same shape as `predictions`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_relative_error` should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_relative_error: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_relative_error`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
+                        'eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+
+   predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
+       predictions, normalizer)
+   predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
+   relative_errors = array_ops.where(
+       math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
+       math_ops.div(math_ops.abs(labels - predictions), normalizer))
+   return mean(relative_errors, weights, metrics_collections,
+               updates_collections, name or 'mean_relative_error')
+
+
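A sketch of `mean_relative_error` with the labels reused as the `normalizer` (illustrative; assumes stock `tf.metrics` behavior):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([1.0, 2.0, 4.0])
predictions = tf.constant([1.5, 2.0, 3.0])
err, update_op = tf.metrics.mean_relative_error(
    labels, predictions, normalizer=labels)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(err))  # mean(0.5/1, 0/2, 1/4) = 0.25
```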
+ @tf_export('metrics.mean_squared_error')
+ def mean_squared_error(labels,
+                        predictions,
+                        weights=None,
+                        metrics_collections=None,
+                        updates_collections=None,
+                        name=None):
+   """Computes the mean squared error between the labels and predictions.
+
+   The `mean_squared_error` function creates two local variables,
+   `total` and `count` that are used to compute the mean squared error.
+   This average is weighted by `weights`, and it is ultimately returned as
+   `mean_squared_error`: an idempotent operation that simply divides `total` by
+   `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `mean_squared_error`. Internally, a `squared_error` operation computes the
+   element-wise square of the difference between `predictions` and `labels`. Then
+   `update_op` increments `total` with the reduced sum of the product of
+   `weights` and `squared_error`, and it increments `count` with the reduced sum
+   of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of the same shape as `predictions`.
+     predictions: A `Tensor` of arbitrary shape.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `mean_squared_error` should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean_squared_error: A `Tensor` representing the current mean, the value of
+       `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_squared_error`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
+                        'eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   squared_error = math_ops.square(labels - predictions)
+   return mean(squared_error, weights, metrics_collections, updates_collections,
+               name or 'mean_squared_error')
+
+
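The same two-op protocol applies to `mean_squared_error`; a minimal sketch (illustrative):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([0.0, 1.0, 2.0])
predictions = tf.constant([0.5, 1.0, 2.5])
mse, update_op = tf.metrics.mean_squared_error(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(mse))  # (0.25 + 0 + 0.25) / 3 ~= 0.1667
```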
+ @tf_export('metrics.mean_tensor')
+ def mean_tensor(values,
+                 weights=None,
+                 metrics_collections=None,
+                 updates_collections=None,
+                 name=None):
+   """Computes the element-wise (weighted) mean of the given tensors.
+
+   In contrast to the `mean` function which returns a scalar with the
+   mean, this function returns an average tensor with the same shape as the
+   input tensors.
+
+   The `mean_tensor` function creates two local variables,
+   `total_tensor` and `count_tensor` that are used to compute the average of
+   `values`. This average is ultimately returned as `mean` which is an idempotent
+   operation that simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `mean`.
+   `update_op` increments `total` with the reduced sum of the product of `values`
+   and `weights`, and it increments `count` with the reduced sum of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A `Tensor` of arbitrary dimensions.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that `mean`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op`
+       should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     mean: A float `Tensor` representing the current mean, the value of `total`
+       divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `mean_value`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.mean_tensor is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'mean', (values, weights)):
+     values = math_ops.to_float(values)
+     total = metric_variable(
+         values.get_shape(), dtypes.float32, name='total_tensor')
+     count = metric_variable(
+         values.get_shape(), dtypes.float32, name='count_tensor')
+
+     num_values = array_ops.ones_like(values)
+     if weights is not None:
+       values, _, weights = _remove_squeezable_dimensions(
+           predictions=values, labels=None, weights=weights)
+       weights = weights_broadcast_ops.broadcast_weights(
+           math_ops.to_float(weights), values)
+       values = math_ops.multiply(values, weights)
+       num_values = math_ops.multiply(num_values, weights)
+
+     update_total_op = state_ops.assign_add(total, values, use_locking=True)
+     with ops.control_dependencies([values]):
+       update_count_op = state_ops.assign_add(
+           count, num_values, use_locking=True)
+
+     compute_mean = lambda _, t, c: _safe_div(t, c, 'value')  # noqa: E731
+
+     mean_t = _aggregate_across_towers(metrics_collections, compute_mean, total,
+                                       count)
+
+     update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return mean_t, update_op
+
+
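Unlike the scalar `mean`, `mean_tensor` keeps one running mean per element; a sketch over two batches (illustrative; assumes a TF 1.x graph session):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

v = tf.placeholder(tf.float32, shape=[2])
mean_t, update_op = tf.metrics.mean_tensor(v)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op, feed_dict={v: [0.0, 2.0]})
  sess.run(update_op, feed_dict={v: [2.0, 4.0]})
  print(sess.run(mean_t))  # [1.0, 3.0]: element-wise mean across both batches
```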
+ @tf_export('metrics.percentage_below')
+ def percentage_below(values,
+                      threshold,
+                      weights=None,
+                      metrics_collections=None,
+                      updates_collections=None,
+                      name=None):
+   """Computes the percentage of values less than the given threshold.
+
+   The `percentage_below` function creates two local variables,
+   `total` and `count` that are used to compute the percentage of `values` that
+   fall below `threshold`. This rate is weighted by `weights`, and it is
+   ultimately returned as `percentage` which is an idempotent operation that
+   simply divides `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `percentage`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A numeric `Tensor` of arbitrary size.
+     threshold: A scalar threshold.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     percentage: A `Tensor` representing the current mean, the value of `total`
+       divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.percentage_below is not supported when '
+                        'eager execution is enabled.')
+
+   is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
+   return mean(is_below_threshold, weights, metrics_collections,
+               updates_collections, name or 'percentage_below_threshold')
+
+
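A one-batch sketch of `percentage_below` (illustrative):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

values = tf.constant([1.0, 2.0, 3.0, 4.0])
pct, update_op = tf.metrics.percentage_below(values, threshold=3.0)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(pct))  # 0.5: two of the four values are strictly below 3.0
```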
+ def _count_condition(values,
+                      weights=None,
+                      metrics_collections=None,
+                      updates_collections=None):
+   """Sums the weights of cases where the given values are True.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     values: A `bool` `Tensor` of arbitrary size.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `values`, and must be broadcastable to `values` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `values` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+   """
+   check_ops.assert_type(values, dtypes.bool)
+   count = metric_variable([], dtypes.float32, name='count')
+
+   values = math_ops.to_float(values)
+   if weights is not None:
+     with ops.control_dependencies(
+         (check_ops.assert_rank_in(weights, (0, array_ops.rank(values))),)):
+       weights = math_ops.to_float(weights)
+       values = math_ops.multiply(values, weights)
+
+   value_tensor = _aggregate_variable(count, metrics_collections)
+
+   update_op = state_ops.assign_add(
+       count, math_ops.reduce_sum(values), use_locking=True)
+   if updates_collections:
+     ops.add_to_collections(updates_collections, update_op)
+
+   return value_tensor, update_op
+
+
+ @tf_export('metrics.false_negatives')
+ def false_negatives(labels,
+                     predictions,
+                     weights=None,
+                     metrics_collections=None,
+                     updates_collections=None,
+                     name=None):
+   """Computes the total number of false negatives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
+       or if either `metrics_collections` or `updates_collections` are not a list
+       or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_negatives is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_negatives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_false_negative = math_ops.logical_and(
+         math_ops.equal(labels, True), math_ops.equal(predictions, False))
+     return _count_condition(is_false_negative, weights, metrics_collections,
+                             updates_collections)
+
+
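A minimal sketch of the confusion-count metrics built on `_count_condition`, here via `false_negatives` (illustrative; assumes stock `tf.metrics` behavior):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([True, True, False, False])
predictions = tf.constant([False, True, False, True])
fn, update_op = tf.metrics.false_negatives(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(fn))  # 1.0: exactly one positive label was predicted negative
```

`false_positives`, `true_negatives`, and `true_positives` below follow the same pattern, differing only in which `logical_and` condition is counted.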
+ @tf_export('metrics.false_negatives_at_thresholds')
+ def false_negatives_at_thresholds(labels,
+                                   predictions,
+                                   thresholds,
+                                   weights=None,
+                                   metrics_collections=None,
+                                   updates_collections=None,
+                                   name=None):
+   """Computes false negatives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `false_negatives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     false_negatives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `false_negatives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_negatives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('fn',))
+
+     fn_value = _aggregate_variable(values['fn'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['fn'])
+
+     return fn_value, update_ops['fn']
+
+
+ @tf_export('metrics.false_positives')
+ def false_positives(labels,
+                     predictions,
+                     weights=None,
+                     metrics_collections=None,
+                     updates_collections=None,
+                     name=None):
+   """Sum the weights of false positives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_positives is not supported when '
+                        'eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_positives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_false_positive = math_ops.logical_and(
+         math_ops.equal(labels, False), math_ops.equal(predictions, True))
+     return _count_condition(is_false_positive, weights, metrics_collections,
+                             updates_collections)
+
+
+ @tf_export('metrics.false_positives_at_thresholds')
+ def false_positives_at_thresholds(labels,
+                                   predictions,
+                                   thresholds,
+                                   weights=None,
+                                   metrics_collections=None,
+                                   updates_collections=None,
+                                   name=None):
+   """Computes false positives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `false_positives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     false_positives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `false_positives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'false_positives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('fp',))
+
+     fp_value = _aggregate_variable(values['fp'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['fp'])
+
+     return fp_value, update_ops['fp']
+
+
+ @tf_export('metrics.true_negatives')
+ def true_negatives(labels,
+                    predictions,
+                    weights=None,
+                    metrics_collections=None,
+                    updates_collections=None,
+                    name=None):
+   """Sum the weights of true_negatives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_negatives is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_negatives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_true_negative = math_ops.logical_and(
+         math_ops.equal(labels, False), math_ops.equal(predictions, False))
+     return _count_condition(is_true_negative, weights, metrics_collections,
+                             updates_collections)
+
+
+ @tf_export('metrics.true_negatives_at_thresholds')
+ def true_negatives_at_thresholds(labels,
+                                  predictions,
+                                  thresholds,
+                                  weights=None,
+                                  metrics_collections=None,
+                                  updates_collections=None,
+                                  name=None):
+   """Computes true negatives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `true_negatives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     true_negatives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `true_negatives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_negatives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('tn',))
+
+     tn_value = _aggregate_variable(values['tn'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['tn'])
+
+     return tn_value, update_ops['tn']
+
+
+ @tf_export('metrics.true_positives')
+ def true_positives(labels,
+                    predictions,
+                    weights=None,
+                    metrics_collections=None,
+                    updates_collections=None,
+                    name=None):
+   """Sum the weights of true_positives.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that the metric
+       value variable should be added to.
+     updates_collections: An optional list of collections that the metric update
+       ops should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     value_tensor: A `Tensor` representing the current value of the metric.
+     update_op: An operation that accumulates the error from a batch of data.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_positives is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_positives',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+     is_true_positive = math_ops.logical_and(
+         math_ops.equal(labels, True), math_ops.equal(predictions, True))
+     return _count_condition(is_true_positive, weights, metrics_collections,
+                             updates_collections)
+
+
+ @tf_export('metrics.true_positives_at_thresholds')
+ def true_positives_at_thresholds(labels,
+                                  predictions,
+                                  thresholds,
+                                  weights=None,
+                                  metrics_collections=None,
+                                  updates_collections=None,
+                                  name=None):
+   """Computes true positives at provided threshold values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+       `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `true_positives`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     true_positives: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that updates the `true_positives` variable and
+       returns its current value.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'true_positives',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights=weights, includes=('tp',))
+
+     tp_value = _aggregate_variable(values['tp'], metrics_collections)
+
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_ops['tp'])
+
+     return tp_value, update_ops['tp']
+
+
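A sketch of the thresholded counts, here via `true_positives_at_thresholds` (illustrative; a prediction counts as positive when it is strictly greater than the threshold):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([False, True, True])
scores = tf.constant([0.4, 0.35, 0.8])
tp, update_op = tf.metrics.true_positives_at_thresholds(
    labels, scores, thresholds=[0.3, 0.7])

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(tp))  # [2.0, 1.0]: both positives exceed 0.3, only 0.8 exceeds 0.7
```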
+ @tf_export('metrics.precision')
+ def precision(labels,
+               predictions,
+               weights=None,
+               metrics_collections=None,
+               updates_collections=None,
+               name=None):
+   """Computes the precision of the predictions with respect to the labels.
+
+   The `precision` function creates two local variables,
+   `true_positives` and `false_positives`, that are used to compute the
+   precision. This value is ultimately returned as `precision`, an idempotent
+   operation that simply divides `true_positives` by the sum of `true_positives`
+   and `false_positives`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision`. `update_op` weights each prediction by the corresponding value in
+   `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
+       be cast to `bool`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `precision` should
+       be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     precision: Scalar float `Tensor` with the value of `true_positives`
+       divided by the sum of `true_positives` and `false_positives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_positives` variables appropriately and whose value matches
+       `precision`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.precision is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'precision',
+                                      (predictions, labels, weights)):
+
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+
+     true_p, true_positives_update_op = true_positives(
+         labels,
+         predictions,
+         weights,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None)
+     false_p, false_positives_update_op = false_positives(
+         labels,
+         predictions,
+         weights,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None)
+
+     def compute_precision(tp, fp, name):
+       return array_ops.where(
+           math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
+
+     def once_across_towers(_, true_p, false_p):
+       return compute_precision(true_p, false_p, 'value')
+
+     p = _aggregate_across_towers(metrics_collections, once_across_towers,
+                                  true_p, false_p)
+
+     update_op = compute_precision(true_positives_update_op,
+                                   false_positives_update_op, 'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return p, update_op
+
+
+ @tf_export('metrics.precision_at_thresholds')
+ def precision_at_thresholds(labels,
+                             predictions,
+                             thresholds,
+                             weights=None,
+                             metrics_collections=None,
+                             updates_collections=None,
+                             name=None):
+   """Computes precision values for different `thresholds` on `predictions`.
+
+   The `precision_at_thresholds` function creates four local variables,
+   `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
+   for various values of thresholds. `precision[i]` is defined as the total
+   weight of values in `predictions` above `thresholds[i]` whose corresponding
+   entry in `labels` is `True`, divided by the total weight of values in
+   `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
+   false_positives[i])`).
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `precision` should
+       be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     precision: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables that
+       are used in the computation of `precision`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.precision_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'precision_at_thresholds',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights, includes=('tp', 'fp'))
+
+     # Avoid division by zero.
+     epsilon = 1e-7
+
+     def compute_precision(tp, fp, name):
+       return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
+
+     def precision_across_towers(_, values):
+       return compute_precision(values['tp'], values['fp'], 'value')
+
+     prec = _aggregate_across_towers(metrics_collections,
+                                     precision_across_towers, values)
+
+     update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+                                   'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return prec, update_op
+
+
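A sketch of `precision_at_thresholds` on the same kind of data (illustrative; values are approximate because of the `epsilon` guard above):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.constant([False, True, True])
scores = tf.constant([0.4, 0.35, 0.8])
prec, update_op = tf.metrics.precision_at_thresholds(
    labels, scores, thresholds=[0.3, 0.7])

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)
  print(sess.run(prec))  # ~[0.6667, 1.0]: tp/(tp + fp) = 2/3 at 0.3, 1/1 at 0.7
```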
+ @tf_export('metrics.recall')
2124
+ def recall(labels,
2125
+ predictions,
2126
+ weights=None,
2127
+ metrics_collections=None,
2128
+ updates_collections=None,
2129
+ name=None):
2130
+ """Computes the recall of the predictions with respect to the labels.
2131
+
2132
+ The `recall` function creates two local variables, `true_positives`
2133
+ and `false_negatives`, that are used to compute the recall. This value is
2134
+ ultimately returned as `recall`, an idempotent operation that simply divides
2135
+ `true_positives` by the sum of `true_positives` and `false_negatives`.
2136
+
2137
+ For estimation of the metric over a stream of data, the function creates an
2138
+ `update_op` that updates these variables and returns the `recall`. `update_op`
2139
+ weights each prediction by the corresponding value in `weights`.
2140
+
2141
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2142
+
2143
+ Args:
2144
+ labels: The ground truth values, a `Tensor` whose dimensions must match
2145
+ `predictions`. Will be cast to `bool`.
2146
+ predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2147
+ be cast to `bool`.
2148
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
2149
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2150
+ be either `1`, or the same as the corresponding `labels` dimension).
2151
+ metrics_collections: An optional list of collections that `recall` should
2152
+ be added to.
2153
+ updates_collections: An optional list of collections that `update_op` should
2154
+ be added to.
2155
+ name: An optional variable_scope name.
2156
+
2157
+ Returns:
2158
+ recall: Scalar float `Tensor` with the value of `true_positives` divided
2159
+ by the sum of `true_positives` and `false_negatives`.
2160
+ update_op: `Operation` that increments `true_positives` and
2161
+ `false_negatives` variables appropriately and whose value matches
2162
+ `recall`.
2163
+
2164
+ Raises:
2165
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
2166
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
2167
+ either `metrics_collections` or `updates_collections` are not a list or
2168
+ tuple.
2169
+ RuntimeError: If eager execution is enabled.
2170
+ """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.recall is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'recall',
+                                      (predictions, labels, weights)):
+     predictions, labels, weights = _remove_squeezable_dimensions(
+         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
+         labels=math_ops.cast(labels, dtype=dtypes.bool),
+         weights=weights)
+
+     true_p, true_positives_update_op = true_positives(
+         labels,
+         predictions,
+         weights,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None)
+     false_n, false_negatives_update_op = false_negatives(
+         labels,
+         predictions,
+         weights,
+         metrics_collections=None,
+         updates_collections=None,
+         name=None)
+
+     def compute_recall(true_p, false_n, name):
+       return array_ops.where(
+           math_ops.greater(true_p + false_n, 0),
+           math_ops.div(true_p, true_p + false_n), 0, name)
+
+     def once_across_towers(_, true_p, false_n):
+       return compute_recall(true_p, false_n, 'value')
+
+     rec = _aggregate_across_towers(metrics_collections, once_across_towers,
+                                    true_p, false_n)
+
+     update_op = compute_recall(true_positives_update_op,
+                                false_negatives_update_op, 'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return rec, update_op
+
+
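For orientation, a minimal graph-mode usage sketch of the streaming `recall` metric above (TF 1.x `tf.metrics` API; the tensor values are invented for illustration). The metric variables are local variables, so they must be initialized before the first `update_op` run:

    import tensorflow as tf  # assumes TF 1.x graph mode

    labels = tf.constant([True, False, True, True])
    predictions = tf.constant([True, False, False, True])
    rec, update_op = tf.metrics.recall(labels, predictions)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())  # metric vars are local
        sess.run(update_op)   # accumulate tp and fn for this batch
        print(sess.run(rec))  # 2 tp / (2 tp + 1 fn) = 0.6667
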
+ def _at_k_name(name, k=None, class_id=None):
+   if k is not None:
+     name = '%s_at_%d' % (name, k)
+   else:
+     name = '%s_at_k' % (name)
+   if class_id is not None:
+     name = '%s_class%d' % (name, class_id)
+   return name
+
+
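A quick sketch of the scope names this helper produces, following its string formatting directly:

    _at_k_name('recall', k=5)              # -> 'recall_at_5'
    _at_k_name('recall', k=5, class_id=3)  # -> 'recall_at_5_class3'
    _at_k_name('precision')                # -> 'precision_at_k'
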
+ def _select_class_id(ids, selected_id):
+   """Filter all but `selected_id` out of `ids`.
+
+   Args:
+     ids: `int64` `Tensor` or `SparseTensor` of IDs.
+     selected_id: Int id to select.
+
+   Returns:
+     `SparseTensor` of same dimensions as `ids`. This contains only the entries
+     equal to `selected_id`.
+   """
+   ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
+   if isinstance(ids, sparse_tensor.SparseTensor):
+     return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
+                                                         selected_id))
+
+   # TODO(ptucker): Make this more efficient, maybe add a sparse version of
+   # tf.equal and tf.reduce_any?
+
+   # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
+   ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
+   ids_last_dim = array_ops.size(ids_shape) - 1
+   filled_selected_id_shape = math_ops.reduced_shape(
+       ids_shape, array_ops.reshape(ids_last_dim, [1]))
+
+   # Intersect `ids` with the selected ID.
+   filled_selected_id = array_ops.fill(filled_selected_id_shape,
+                                       math_ops.to_int64(selected_id))
+   result = sets.set_intersection(filled_selected_id, ids)
+   return sparse_tensor.SparseTensor(
+       indices=result.indices, values=result.values, dense_shape=ids_shape)
+
+
+ def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
+   """If class ID is specified, filter all other classes.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels]. [D1, ... DN] must match
+       `predictions_idx`.
+     predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
+       where N >= 1. Commonly, N=1 and `predictions_idx` has shape
+       [batch size, k].
+     selected_id: Int id to select.
+
+   Returns:
+     Tuple of `labels` and `predictions_idx`, possibly with classes removed.
+   """
+   if selected_id is None:
+     return labels, predictions_idx
+   return (_select_class_id(labels, selected_id),
+           _select_class_id(predictions_idx, selected_id))
+
+
+ def _sparse_true_positive_at_k(labels,
+                                predictions_idx,
+                                class_id=None,
+                                weights=None,
+                                name=None):
+   """Calculates true positives for recall@k and precision@k.
+
+   If `class_id` is specified, calculate binary true positives for `class_id`
+   only.
+   If `class_id` is not specified, calculate metrics for `k` predicted vs
+   `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels]. [D1, ... DN] must match
+       `predictions_idx`.
+     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
+       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
+       match `labels`.
+     class_id: Class for which we want binary metrics.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     name: Name of operation.
+
+   Returns:
+     A [D1, ... DN] `Tensor` of true positive counts.
+   """
+   with ops.name_scope(name, 'true_positives',
+                       (predictions_idx, labels, weights)):
+     labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
+                                                      class_id)
+     tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
+     tp = math_ops.to_double(tp)
+     if weights is not None:
+       with ops.control_dependencies(
+           (weights_broadcast_ops.assert_broadcastable(weights, tp),)):
+         weights = math_ops.to_double(weights)
+         tp = math_ops.multiply(tp, weights)
+     return tp
+
+
+ def _streaming_sparse_true_positive_at_k(labels,
+                                          predictions_idx,
+                                          k=None,
+                                          class_id=None,
+                                          weights=None,
+                                          name=None):
+   """Calculates weighted per step true positives for recall@k and precision@k.
+
+   If `class_id` is specified, calculate binary true positives for `class_id`
+   only.
+   If `class_id` is not specified, calculate metrics for `k` predicted vs
+   `n` label classes, where `n` is the 2nd dimension of `labels`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels]. [D1, ... DN] must match
+       `predictions_idx`.
+     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
+       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
+       match `labels`.
+     k: Integer, k for @k metric. This is only used for default op name.
+     class_id: Class for which we want binary metrics.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     name: Name of new variable, and namespace for other dependent ops.
+
+   Returns:
+     A tuple of `Variable` and update `Operation`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and has an incompatible shape.
+   """
+   with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
+                       (predictions_idx, labels, weights)) as scope:
+     tp = _sparse_true_positive_at_k(
+         predictions_idx=predictions_idx,
+         labels=labels,
+         class_id=class_id,
+         weights=weights)
+     batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
+
+     var = metric_variable([], dtypes.float64, name=scope)
+     return var, state_ops.assign_add(
+         var, batch_total_tp, name='update', use_locking=True)
+
+
+ def _sparse_false_negative_at_k(labels,
+                                 predictions_idx,
+                                 class_id=None,
+                                 weights=None):
+   """Calculates false negatives for recall@k.
+
+   If `class_id` is specified, calculate binary false negatives for `class_id`
+   only.
+   If `class_id` is not specified, calculate metrics for `k` predicted vs
+   `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels]. [D1, ... DN] must match
+       `predictions_idx`.
+     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
+       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
+       match `labels`.
+     class_id: Class for which we want binary metrics.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+
+   Returns:
+     A [D1, ... DN] `Tensor` of false negative counts.
+   """
+   with ops.name_scope(None, 'false_negatives',
+                       (predictions_idx, labels, weights)):
+     labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
+                                                      class_id)
+     fn = sets.set_size(
+         sets.set_difference(predictions_idx, labels, aminusb=False))
+     fn = math_ops.to_double(fn)
+     if weights is not None:
+       with ops.control_dependencies(
+           (weights_broadcast_ops.assert_broadcastable(weights, fn),)):
+         weights = math_ops.to_double(weights)
+         fn = math_ops.multiply(fn, weights)
+     return fn
+
+
+ def _streaming_sparse_false_negative_at_k(labels,
+                                           predictions_idx,
+                                           k,
+                                           class_id=None,
+                                           weights=None,
+                                           name=None):
+   """Calculates weighted per step false negatives for recall@k.
+
+   If `class_id` is specified, calculate binary false negatives for `class_id`
+   only.
+   If `class_id` is not specified, calculate metrics for `k` predicted vs
+   `n` label classes, where `n` is the 2nd dimension of `labels`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels]. [D1, ... DN] must match
+       `predictions_idx`.
+     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
+       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
+       match `labels`.
+     k: Integer, k for @k metric. This is only used for default op name.
+     class_id: Class for which we want binary metrics.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     name: Name of new variable, and namespace for other dependent ops.
+
+   Returns:
+     A tuple of `Variable` and update `Operation`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and has an incompatible shape.
+   """
+   with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
+                       (predictions_idx, labels, weights)) as scope:
+     fn = _sparse_false_negative_at_k(
+         predictions_idx=predictions_idx,
+         labels=labels,
+         class_id=class_id,
+         weights=weights)
+     batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
+
+     var = metric_variable([], dtypes.float64, name=scope)
+     return var, state_ops.assign_add(
+         var, batch_total_fn, name='update', use_locking=True)
+
+
+ @tf_export('metrics.recall_at_k')
+ def recall_at_k(labels,
+                 predictions,
+                 k,
+                 class_id=None,
+                 weights=None,
+                 metrics_collections=None,
+                 updates_collections=None,
+                 name=None):
+   """Computes recall@k of the predictions with respect to sparse labels.
+
+   If `class_id` is specified, we calculate recall by considering only the
+   entries in the batch for which `class_id` is in the label, and computing
+   the fraction of them for which `class_id` is in the top-k `predictions`.
+   If `class_id` is not specified, we'll calculate recall as how often on
+   average a class among the labels of a batch entry is in the top-k
+   `predictions`.
+
+   `sparse_recall_at_k` creates two local variables,
+   `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
+   the recall_at_k frequency. This frequency is ultimately returned as
+   `recall_at_<k>`: an idempotent operation that simply divides
+   `true_positive_at_<k>` by total (`true_positive_at_<k>` +
+   `false_negative_at_<k>`).
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
+   indicating the top `k` `predictions`. Set operations applied to `top_k` and
+   `labels` calculate the true positives and false negatives weighted by
+   `weights`. Then `update_op` increments `true_positive_at_<k>` and
+   `false_negative_at_<k>` using these values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+       should be in range [0, num_classes), where num_classes is the last
+       dimension of `predictions`. Values outside this range always count
+       towards `false_negative_at_<k>`.
+     predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
+       N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
+       The final dimension contains the logit values for each class. [D1, ... DN]
+       must match `labels`.
+     k: Integer, k for @k metric.
+     class_id: Integer class ID for which we want binary metrics. This should be
+       in range [0, num_classes), where num_classes is the last dimension of
+       `predictions`. If class_id is outside this range, the method returns NAN.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+       by the sum of `true_positives` and `false_negatives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_negatives` variables appropriately, and whose value matches
+       `recall`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match
+       `predictions`, or if either `metrics_collections` or `updates_collections`
+       are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.recall_at_k is not '
+                        'supported when eager execution is enabled.')
+
+   with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
+                       (predictions, labels, weights)) as scope:
+     _, top_k_idx = nn.top_k(predictions, k)
+     return recall_at_top_k(
+         labels=labels,
+         predictions_idx=top_k_idx,
+         k=k,
+         class_id=class_id,
+         weights=weights,
+         metrics_collections=metrics_collections,
+         updates_collections=updates_collections,
+         name=scope)
+
+
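A minimal usage sketch for `recall_at_k` (TF 1.x graph mode; logits and labels invented for illustration). Each row's sparse labels are checked against the indices of the top-k logits:

    import tensorflow as tf  # assumes TF 1.x graph mode

    # Two examples, three classes; labels hold sparse class ids.
    logits = tf.constant([[0.1, 0.7, 0.2],
                          [0.9, 0.05, 0.03]])
    labels = tf.constant([[2], [0]], dtype=tf.int64)

    rec, update_op = tf.metrics.recall_at_k(labels, logits, k=2)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(update_op)
        # Row 1: top-2 = {1, 2} contains label 2; row 2: top-2 = {0, 1}
        # contains label 0 -> recall@2 = 2 / 2 = 1.0
        print(sess.run(rec))
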
+ @tf_export('metrics.recall_at_top_k')
+ def recall_at_top_k(labels,
+                     predictions_idx,
+                     k=None,
+                     class_id=None,
+                     weights=None,
+                     metrics_collections=None,
+                     updates_collections=None,
+                     name=None):
+   """Computes recall@k of top-k predictions with respect to sparse labels.
+
+   Differs from `recall_at_k` in that predictions must be in the form of top `k`
+   class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
+   for more details.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+       should be in range [0, num_classes), where num_classes is the last
+       dimension of `predictions`. Values outside this range always count
+       towards `false_negative_at_<k>`.
+     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
+       Commonly, N=1 and predictions has shape [batch size, k]. The final
+       dimension contains the top `k` predicted class indices. [D1, ... DN] must
+       match `labels`.
+     k: Integer, k for @k metric. Only used for the default op name.
+     class_id: Integer class ID for which we want binary metrics. This should be
+       in range [0, num_classes), where num_classes is the last dimension of
+       `predictions`. If class_id is outside this range, the method returns NAN.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
+       by the sum of `true_positives` and `false_negatives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_negatives` variables appropriately, and whose value matches
+       `recall`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match
+       `predictions`, or if either `metrics_collections` or `updates_collections`
+       are not a list or tuple.
+   """
+   with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
+                       (predictions_idx, labels, weights)) as scope:
+     labels = _maybe_expand_labels(labels, predictions_idx)
+     top_k_idx = math_ops.to_int64(predictions_idx)
+     tp, tp_update = _streaming_sparse_true_positive_at_k(
+         predictions_idx=top_k_idx,
+         labels=labels,
+         k=k,
+         class_id=class_id,
+         weights=weights)
+     fn, fn_update = _streaming_sparse_false_negative_at_k(
+         predictions_idx=top_k_idx,
+         labels=labels,
+         k=k,
+         class_id=class_id,
+         weights=weights)
+
+     def compute_recall(_, tp, fn):
+       return math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+
+     metric = _aggregate_across_towers(metrics_collections, compute_recall, tp,
+                                       fn)
+
+     update = math_ops.div(
+         tp_update, math_ops.add(tp_update, fn_update), name='update')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update)
+     return metric, update
+
+
+ @tf_export('metrics.recall_at_thresholds')
+ def recall_at_thresholds(labels,
+                          predictions,
+                          thresholds,
+                          weights=None,
+                          metrics_collections=None,
+                          updates_collections=None,
+                          name=None):
+   """Computes various recall values for different `thresholds` on `predictions`.
+
+   The `recall_at_thresholds` function creates four local variables,
+   `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
+   for various values of thresholds. `recall[i]` is defined as the total weight
+   of values in `predictions` above `thresholds[i]` whose corresponding entry in
+   `labels` is `True`, divided by the total weight of `True` values in `labels`
+   (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the `recall`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that `recall` should be
+       added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     recall: A float `Tensor` of shape `[len(thresholds)]`.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables that
+       are used in the computation of `recall`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.recall_at_thresholds is not '
+                        'supported when eager execution is enabled.')
+
+   with variable_scope.variable_scope(name, 'recall_at_thresholds',
+                                      (predictions, labels, weights)):
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights, includes=('tp', 'fn'))
+
+     # Avoid division by zero.
+     epsilon = 1e-7
+
+     def compute_recall(tp, fn, name):
+       return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
+
+     def recall_across_towers(_, values):
+       return compute_recall(values['tp'], values['fn'], 'value')
+
+     rec = _aggregate_across_towers(metrics_collections, recall_across_towers,
+                                    values)
+
+     update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return rec, update_op
+
+
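A short usage sketch (TF 1.x graph mode; scores invented). One recall value is reported per threshold:

    import tensorflow as tf  # assumes TF 1.x graph mode

    labels = tf.constant([True, False, True, True])
    scores = tf.constant([0.9, 0.6, 0.4, 0.2])

    rec, update_op = tf.metrics.recall_at_thresholds(
        labels, scores, thresholds=[0.3, 0.5])

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(update_op)
        # ~[0.6667, 0.3333]: 2 of 3 positives score above 0.3, 1 of 3 above 0.5
        print(sess.run(rec))
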
+ @tf_export('metrics.root_mean_squared_error')
+ def root_mean_squared_error(labels,
+                             predictions,
+                             weights=None,
+                             metrics_collections=None,
+                             updates_collections=None,
+                             name=None):
+   """Computes the root mean squared error between the labels and predictions.
+
+   The `root_mean_squared_error` function creates two local variables,
+   `total` and `count` that are used to compute the root mean squared error.
+   This average is weighted by `weights`, and it is ultimately returned as
+   `root_mean_squared_error`: an idempotent operation that takes the square root
+   of the division of `total` by `count`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `root_mean_squared_error`. Internally, a `squared_error` operation computes
+   the element-wise square of the difference between `predictions` and `labels`.
+   Then `update_op` increments `total` with the reduced sum of the product of
+   `weights` and `squared_error`, and it increments `count` with the reduced sum
+   of `weights`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: A `Tensor` of the same shape as `predictions`.
+     predictions: A `Tensor` of arbitrary shape.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     metrics_collections: An optional list of collections that
+       `root_mean_squared_error` should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     root_mean_squared_error: A `Tensor` representing the current mean, the value
+       of `total` divided by `count`.
+     update_op: An operation that increments the `total` and `count` variables
+       appropriately and whose value matches `root_mean_squared_error`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, or if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       either `metrics_collections` or `updates_collections` are not a list or
+       tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.root_mean_squared_error is not '
+                        'supported when eager execution is enabled.')
+
+   predictions, labels, weights = _remove_squeezable_dimensions(
+       predictions=predictions, labels=labels, weights=weights)
+   mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
+                                           None, name or
+                                           'root_mean_squared_error')
+
+   once_across_towers = lambda _, mse: math_ops.sqrt(mse)  # noqa: E731
+   rmse = _aggregate_across_towers(metrics_collections, once_across_towers, mse)
+
+   update_rmse_op = math_ops.sqrt(update_mse_op)
+   if updates_collections:
+     ops.add_to_collections(updates_collections, update_rmse_op)
+
+   return rmse, update_rmse_op
+
+
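A minimal sketch of the streaming RMSE (TF 1.x graph mode; values invented):

    import tensorflow as tf  # assumes TF 1.x graph mode

    labels = tf.constant([1.0, 2.0, 3.0])
    predictions = tf.constant([1.0, 2.0, 5.0])

    rmse, update_op = tf.metrics.root_mean_squared_error(labels, predictions)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(update_op)
        print(sess.run(rmse))  # sqrt((0 + 0 + 4) / 3) ~= 1.1547
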
+ @tf_export('metrics.sensitivity_at_specificity')
+ def sensitivity_at_specificity(labels,
+                                predictions,
+                                specificity,
+                                weights=None,
+                                num_thresholds=200,
+                                metrics_collections=None,
+                                updates_collections=None,
+                                name=None):
+   """Computes the sensitivity at a given specificity.
+
+   The `sensitivity_at_specificity` function creates four local
+   variables, `true_positives`, `true_negatives`, `false_positives` and
+   `false_negatives` that are used to compute the sensitivity at the given
+   specificity value. The threshold for the given specificity value is computed
+   and used to evaluate the corresponding sensitivity.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
+   `false_positives` and `false_negatives` counts with the weight of each case
+   found in the `predictions` and `labels`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   For additional information about specificity and sensitivity, see the
+   following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     specificity: A scalar value in range `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+       be either `1`, or the same as the corresponding `labels` dimension).
+     num_thresholds: The number of thresholds to use for matching the given
+       specificity.
+     metrics_collections: An optional list of collections that `sensitivity`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op` should
+       be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     sensitivity: A scalar `Tensor` representing the sensitivity at the given
+       `specificity` value.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables
+       appropriately and whose value matches `sensitivity`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       `specificity` is not between 0 and 1, or if either `metrics_collections`
+       or `updates_collections` are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
+                        'supported when eager execution is enabled.')
+
+   if specificity < 0 or specificity > 1:
+     raise ValueError('`specificity` must be in the range [0, 1].')
+
+   with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
+                                      (predictions, labels, weights)):
+     kepsilon = 1e-7  # to account for floating point imprecisions
+     thresholds = [
+         (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+     ]
+     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
+
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights)
+
+     def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
+       specificities = math_ops.div(tn, tn + fp + kepsilon)
+       tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
+       tf_index = math_ops.cast(tf_index, dtypes.int32)
+
+       # Now, we have the implicit threshold, so compute the sensitivity:
+       return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
+                           name)
+
+     def sensitivity_across_towers(_, values):
+       return compute_sensitivity_at_specificity(values['tp'], values['tn'],
+                                                 values['fp'], values['fn'],
+                                                 'value')
+
+     sensitivity = _aggregate_across_towers(metrics_collections,
+                                            sensitivity_across_towers, values)
+
+     update_op = compute_sensitivity_at_specificity(update_ops['tp'],
+                                                    update_ops['tn'],
+                                                    update_ops['fp'],
+                                                    update_ops['fn'],
+                                                    'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return sensitivity, update_op
+
+
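A usage sketch (TF 1.x graph mode; scores invented). The function scans its internal threshold grid for the threshold whose specificity is closest to the target, then reports the sensitivity at that threshold:

    import tensorflow as tf  # assumes TF 1.x graph mode

    labels = tf.constant([True, True, False, False, False, False])
    scores = tf.constant([0.9, 0.4, 0.35, 0.3, 0.2, 0.1])

    sens, update_op = tf.metrics.sensitivity_at_specificity(
        labels, scores, specificity=0.75)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(update_op)
        # Exact value depends on the 200-point threshold grid above.
        print(sess.run(sens))
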
+ def _expand_and_tile(tensor, multiple, dim=0, name=None):
+   """Slice `tensor` shape in 2, then tile along the sliced dimension.
+
+   A new dimension is inserted in shape of `tensor` before `dim`, then values are
+   tiled `multiple` times along the new dimension.
+
+   Args:
+     tensor: Input `Tensor` or `SparseTensor`.
+     multiple: Integer, number of times to tile.
+     dim: Integer, dimension along which to tile.
+     name: Name of operation.
+
+   Returns:
+     `Tensor` result of expanding and tiling `tensor`.
+
+   Raises:
+     ValueError: if `multiple` is less than 1, or `dim` is not in
+       `[-rank(tensor), rank(tensor)]`.
+   """
+   if multiple < 1:
+     raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
+   with ops.name_scope(name, 'expand_and_tile',
+                       (tensor, multiple, dim)) as scope:
+     # Sparse.
+     tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
+     if isinstance(tensor, sparse_tensor.SparseTensor):
+       if dim < 0:
+         expand_dims = array_ops.reshape(
+             array_ops.size(tensor.dense_shape) + dim, [1])
+       else:
+         expand_dims = [dim]
+       expanded_shape = array_ops.concat(
+           (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
+            array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
+           0,
+           name='expanded_shape')
+       expanded = sparse_ops.sparse_reshape(
+           tensor, shape=expanded_shape, name='expand')
+       if multiple == 1:
+         return expanded
+       return sparse_ops.sparse_concat(
+           dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
+
+     # Dense.
+     expanded = array_ops.expand_dims(
+         tensor, dim if (dim >= 0) else (dim - 1), name='expand')
+     if multiple == 1:
+       return expanded
+     ones = array_ops.ones_like(array_ops.shape(tensor))
+     tile_multiples = array_ops.concat((ones[:dim], (multiple,), ones[dim:]),
+                                       0,
+                                       name='multiples')
+     return array_ops.tile(expanded, tile_multiples, name=scope)
+
+
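A shape-level sketch of the dense branch, written with public ops (this is the equivalent of `_expand_and_tile(labels, multiple=3, dim=-1)` and is how `labels_per_k` is built further below):

    import tensorflow as tf  # assumes TF 1.x graph mode

    labels = tf.constant([[1, 2], [3, 4]])   # shape [2, 2]
    expanded = tf.expand_dims(labels, -2)    # shape [2, 1, 2]
    tiled = tf.tile(expanded, [1, 3, 1])     # shape [2, 3, 2]
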
+ def _num_relevant(labels, k):
+   """Computes number of relevant values for each row in labels.
+
+   For labels with shape [D1, ... DN, num_labels], this is the minimum of
+   `num_labels` and `k`.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels].
+     k: Integer, k for @k metric.
+
+   Returns:
+     Integer `Tensor` of shape [D1, ... DN], where each value is the number of
+     relevant values for that row.
+
+   Raises:
+     ValueError: if inputs have invalid dtypes or values.
+   """
+   if k < 1:
+     raise ValueError('Invalid k=%s.' % k)
+   with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
+     # For SparseTensor, calculate separate count for each row.
+     labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
+     if isinstance(labels, sparse_tensor.SparseTensor):
+       return math_ops.minimum(sets.set_size(labels), k, name=scope)
+
+     # For dense Tensor, calculate scalar count based on last dimension, and
+     # tile across labels shape.
+     labels_shape = array_ops.shape(labels)
+     labels_size = labels_shape[-1]
+     num_relevant_scalar = math_ops.minimum(labels_size, k)
+     return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope)
+
+
+ def _sparse_average_precision_at_top_k(labels, predictions_idx):
+   """Computes average precision@k of predictions with respect to sparse labels.
+
+   From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
+   for each row is:
+
+     AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
+
+   A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
+   `labels`, and the result `Tensors`. In the common case, this is [batch_size].
+   Each row of the results contains the average precision for that row.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
+       Values should be in range [0, num_classes).
+     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
+       Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
+       dimension must be set and contains the top `k` predicted class indices.
+       [D1, ... DN] must match `labels`. Values should be in range
+       [0, num_classes).
+
+   Returns:
+     `float64` `Tensor` of shape [D1, ... DN], where each value is the average
+     precision for that row.
+
+   Raises:
+     ValueError: if the last dimension of predictions_idx is not set.
+   """
+   with ops.name_scope(None, 'average_precision',
+                       (predictions_idx, labels)) as scope:
+     predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')
+     if predictions_idx.get_shape().ndims == 0:
+       raise ValueError('The rank of predictions_idx must be at least 1.')
+     k = predictions_idx.get_shape().as_list()[-1]
+     if k is None:
+       raise ValueError('The last dimension of predictions_idx must be set.')
+     labels = _maybe_expand_labels(labels, predictions_idx)
+
+     # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
+     # prediction for each k, so we can calculate separate true positive values
+     # for each k.
+     predictions_idx_per_k = array_ops.expand_dims(
+         predictions_idx, -1, name='predictions_idx_per_k')
+
+     # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
+     labels_per_k = _expand_and_tile(
+         labels, multiple=k, dim=-1, name='labels_per_k')
+
+     # The following tensors are all of shape [D1, ... DN, k], containing values
+     # per row, per k value.
+     # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
+     #     that k value is correct, 0 otherwise. This is the "rel_{i}" term from
+     #     the formula above.
+     # `tp_per_k` (int32) - True positive counts.
+     # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
+     #     the precision denominator.
+     # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
+     #     term from the formula above.
+     # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
+     #     precisions at all k for which relevance indicator is true.
+     relevant_per_k = _sparse_true_positive_at_k(
+         labels_per_k, predictions_idx_per_k, name='relevant_per_k')
+     tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
+     retrieved_per_k = math_ops.cumsum(
+         array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
+     precision_per_k = math_ops.div(
+         math_ops.to_double(tp_per_k),
+         math_ops.to_double(retrieved_per_k),
+         name='precision_per_k')
+     relevant_precision_per_k = math_ops.multiply(
+         precision_per_k,
+         math_ops.to_double(relevant_per_k),
+         name='relevant_precision_per_k')
+
+     # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
+     precision_sum = math_ops.reduce_sum(
+         relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum')
+
+     # Divide by number of relevant items to get average precision. These are
+     # the "num_relevant_items" and "AveP" terms from the formula above.
+     num_relevant_items = math_ops.to_double(_num_relevant(labels, k))
+     return math_ops.div(precision_sum, num_relevant_items, name=scope)
+
+
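A worked instance of the AveP formula above, for a single row (numbers invented), mirroring the per-k tensors the function builds:

    # One row, k=3: predictions_idx = [5, 1, 3], label set = {3, 5}.
    relevant_per_k  = [1, 0, 1]         # hits at position 1 (class 5) and 3 (class 3)
    tp_per_k        = [1, 1, 2]         # cumulative sum of relevant_per_k
    precision_per_k = [1/1, 1/2, 2/3]   # tp_per_k / positions retrieved so far
    # AveP = sum(P_i * rel_i) / num_relevant = (1.0 + 0 + 2/3) / min(2, 3)
    avep = (1.0 * 1 + 0.5 * 0 + (2 / 3) * 1) / min(2, 3)
    print(avep)  # 0.8333...
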
+ def _streaming_sparse_average_precision_at_top_k(labels,
+                                                  predictions_idx,
+                                                  weights=None,
+                                                  metrics_collections=None,
+                                                  updates_collections=None,
+                                                  name=None):
+   """Computes average precision@k of predictions with respect to sparse labels.
+
+   `sparse_average_precision_at_top_k` creates two local variables,
+   `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
+   are used to compute the frequency. This frequency is ultimately returned as
+   `average_precision_at_<k>`: an idempotent operation that simply divides
+   `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate
+   the true positives and false positives weighted by `weights`. Then `update_op`
+   increments `true_positive_at_<k>` and `false_positive_at_<k>` using these
+   values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
+       Values should be in range [0, num_classes).
+     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
+       Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
+       dimension contains the top `k` predicted class indices. [D1, ... DN] must
+       match `labels`. Values should be in range [0, num_classes).
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     mean_average_precision: Scalar `float64` `Tensor` with the mean average
+       precision values.
+     update: `Operation` that increments variables appropriately, and whose
+       value matches `metric`.
+   """
+   with ops.name_scope(name, 'average_precision_at_top_k',
+                       (predictions_idx, labels, weights)) as scope:
+     # Calculate per-example average precision, and apply weights.
+     average_precision = _sparse_average_precision_at_top_k(
+         predictions_idx=predictions_idx, labels=labels)
+     if weights is not None:
+       weights = weights_broadcast_ops.broadcast_weights(
+           math_ops.to_double(weights), average_precision)
+       average_precision = math_ops.multiply(average_precision, weights)
+
+     # Create accumulation variables and update ops for max average precision and
+     # total average precision.
+     with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
+       # `max` is the max possible precision. Since max for any row is 1.0:
+       # - For the unweighted case, this is just the number of rows.
+       # - For the weighted case, it's the sum of the weights broadcast across
+       #   `average_precision` rows.
+       max_var = metric_variable([], dtypes.float64, name=max_scope)
+       if weights is None:
+         batch_max = math_ops.to_double(
+             array_ops.size(average_precision, name='batch_max'))
+       else:
+         batch_max = math_ops.reduce_sum(weights, name='batch_max')
+       max_update = state_ops.assign_add(
+           max_var, batch_max, name='update', use_locking=True)
+     with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
+       total_var = metric_variable([], dtypes.float64, name=total_scope)
+       batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
+       total_update = state_ops.assign_add(
+           total_var, batch_total, name='update', use_locking=True)
+
+     # Divide total by max to get mean, for both vars and the update ops.
+     def precision_across_towers(_, total_var, max_var):
+       return _safe_scalar_div(total_var, max_var, name='mean')
+
+     mean_average_precision = _aggregate_across_towers(metrics_collections,
+                                                       precision_across_towers,
+                                                       total_var, max_var)
+
+     update = _safe_scalar_div(total_update, max_update, name=scope)
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update)
+
+     return mean_average_precision, update
+
+
+ @tf_export('metrics.sparse_average_precision_at_k')
+ @deprecated(None, 'Use average_precision_at_k instead')
+ def sparse_average_precision_at_k(labels,
+                                   predictions,
+                                   k,
+                                   weights=None,
+                                   metrics_collections=None,
+                                   updates_collections=None,
+                                   name=None):
+   """Renamed to `average_precision_at_k`, please use that method instead."""
+   return average_precision_at_k(
+       labels=labels,
+       predictions=predictions,
+       k=k,
+       weights=weights,
+       metrics_collections=metrics_collections,
+       updates_collections=updates_collections,
+       name=name)
+
+
+ @tf_export('metrics.average_precision_at_k')
+ def average_precision_at_k(labels,
+                            predictions,
+                            k,
+                            weights=None,
+                            metrics_collections=None,
+                            updates_collections=None,
+                            name=None):
+   """Computes average precision@k of predictions with respect to sparse labels.
+
+   `average_precision_at_k` creates two local variables,
+   `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
+   are used to compute the frequency. This frequency is ultimately returned as
+   `average_precision_at_<k>`: an idempotent operation that simply divides
+   `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
+   indicating the top `k` `predictions`. Set operations applied to `top_k` and
+   `labels` calculate the true positives and false positives weighted by
+   `weights`. Then `update_op` increments `true_positive_at_<k>` and
+   `false_positive_at_<k>` using these values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+       should be in range [0, num_classes), where num_classes is the last
+       dimension of `predictions`. Values outside this range are ignored.
+     predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
+       N >= 1. Commonly, N=1 and `predictions` has shape
+       [batch size, num_classes]. The final dimension contains the logit values
+       for each class. [D1, ... DN] must match `labels`.
+     k: Integer, k for @k metric. This will calculate an average precision for
+       range `[1,k]`, as documented above.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     mean_average_precision: Scalar `float64` `Tensor` with the mean average
+       precision values.
+     update: `Operation` that increments variables appropriately, and whose
+       value matches `metric`.
+
+   Raises:
+     ValueError: if k is invalid.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.average_precision_at_k is not '
+                        'supported when eager execution is enabled.')
+
+   if k < 1:
+     raise ValueError('Invalid k=%s.' % k)
+   with ops.name_scope(name, _at_k_name('average_precision', k),
+                       (predictions, labels, weights)) as scope:
+     # Calculate top k indices to produce [D1, ... DN, k] tensor.
+     _, predictions_idx = nn.top_k(predictions, k)
+     return _streaming_sparse_average_precision_at_top_k(
+         labels=labels,
+         predictions_idx=predictions_idx,
+         weights=weights,
+         metrics_collections=metrics_collections,
+         updates_collections=updates_collections,
+         name=scope)
+
+
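A minimal usage sketch (TF 1.x graph mode; logits invented), tying back to the AveP arithmetic above:

    import tensorflow as tf  # assumes TF 1.x graph mode

    logits = tf.constant([[0.2, 0.5, 0.3]])      # one example, three classes
    labels = tf.constant([[2]], dtype=tf.int64)  # true class is 2

    ap, update_op = tf.metrics.average_precision_at_k(labels, logits, k=2)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(update_op)
        # top-2 is [1, 2]; class 2 is hit at rank 2 -> (1/2) / 1 = 0.5
        print(sess.run(ap))
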
+ def _sparse_false_positive_at_k(labels,
+                                 predictions_idx,
+                                 class_id=None,
+                                 weights=None):
+   """Calculates false positives for precision@k.
+
+   If `class_id` is specified, calculate binary false positives for `class_id`
+   only.
+   If `class_id` is not specified, calculate metrics for `k` predicted vs
+   `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels]. [D1, ... DN] must match
+       `predictions_idx`.
+     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
+       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
+       match `labels`.
+     class_id: Class for which we want binary metrics.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+
+   Returns:
+     A [D1, ... DN] `Tensor` of false positive counts.
+   """
+   with ops.name_scope(None, 'false_positives',
+                       (predictions_idx, labels, weights)):
+     labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
+                                                      class_id)
+     fp = sets.set_size(
+         sets.set_difference(predictions_idx, labels, aminusb=True))
+     fp = math_ops.to_double(fp)
+     if weights is not None:
+       with ops.control_dependencies(
+           (weights_broadcast_ops.assert_broadcastable(weights, fp),)):
+         weights = math_ops.to_double(weights)
+         fp = math_ops.multiply(fp, weights)
+     return fp
+
+
+ def _streaming_sparse_false_positive_at_k(labels,
+                                           predictions_idx,
+                                           k=None,
+                                           class_id=None,
+                                           weights=None,
+                                           name=None):
+   """Calculates weighted per step false positives for precision@k.
+
+   If `class_id` is specified, calculate binary false positives for `class_id`
+   only.
+   If `class_id` is not specified, calculate metrics for `k` predicted vs
+   `n` label classes, where `n` is the 2nd dimension of `labels`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
+       target classes for the associated prediction. Commonly, N=1 and `labels`
+       has shape [batch_size, num_labels]. [D1, ... DN] must match
+       `predictions_idx`.
+     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
+       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
+       match `labels`.
+     k: Integer, k for @k metric. This is only used for default op name.
+     class_id: Class for which we want binary metrics.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     name: Name of new variable, and namespace for other dependent ops.
+
+   Returns:
+     A tuple of `Variable` and update `Operation`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and has an incompatible shape.
+   """
+   with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
+                       (predictions_idx, labels, weights)) as scope:
+     fp = _sparse_false_positive_at_k(
+         predictions_idx=predictions_idx,
+         labels=labels,
+         class_id=class_id,
+         weights=weights)
+     batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
+
+     var = metric_variable([], dtypes.float64, name=scope)
+     return var, state_ops.assign_add(
+         var, batch_total_fp, name='update', use_locking=True)
+
+
+ @tf_export('metrics.precision_at_top_k')
+ def precision_at_top_k(labels,
+                        predictions_idx,
+                        k=None,
+                        class_id=None,
+                        weights=None,
+                        metrics_collections=None,
+                        updates_collections=None,
+                        name=None):
+   """Computes precision@k of the predictions with respect to sparse labels.
+
+   Differs from `sparse_precision_at_k` in that predictions must be in the form
+   of top `k` class indices, whereas `sparse_precision_at_k` expects logits.
+   Refer to `sparse_precision_at_k` for more details.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+       should be in range [0, num_classes), where num_classes is the last
+       dimension of `predictions`. Values outside this range are ignored.
+     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
+       N >= 1. Commonly, N=1 and `predictions_idx` has shape [batch_size, k].
+       The final dimension contains the top `k` predicted class indices.
+       [D1, ... DN] must match `labels`.
+     k: Integer, k for @k metric. Only used for the default op name.
+     class_id: Integer class ID for which we want binary metrics. This should be
+       in range [0, num_classes), where num_classes is the last dimension of
+       `predictions`. If `class_id` is outside this range, the method returns
+       NAN.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     precision: Scalar `float64` `Tensor` with the value of `true_positives`
+       divided by the sum of `true_positives` and `false_positives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_positives` variables appropriately, and whose value matches
+       `precision`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match
+       `predictions`, or if either `metrics_collections` or
+       `updates_collections` are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.precision_at_top_k is not '
+                        'supported when eager execution is enabled.')
+
+   with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
+                       (predictions_idx, labels, weights)) as scope:
+     labels = _maybe_expand_labels(labels, predictions_idx)
+     top_k_idx = math_ops.to_int64(predictions_idx)
+     tp, tp_update = _streaming_sparse_true_positive_at_k(
+         predictions_idx=top_k_idx,
+         labels=labels,
+         k=k,
+         class_id=class_id,
+         weights=weights)
+     fp, fp_update = _streaming_sparse_false_positive_at_k(
+         predictions_idx=top_k_idx,
+         labels=labels,
+         k=k,
+         class_id=class_id,
+         weights=weights)
+
+     def precision_across_towers(_, tp, fp):
+       return math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+
+     metric = _aggregate_across_towers(metrics_collections,
+                                       precision_across_towers, tp, fp)
+
+     update = math_ops.div(
+         tp_update, math_ops.add(tp_update, fp_update), name='update')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update)
+     return metric, update
+
+
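+ # Illustrative usage sketch (graph-mode TensorFlow 1.x assumed; the
+ # `_example_*` name and tensor values are made up): the caller supplies
+ # precomputed top-k class indices instead of logits.
+ def _example_precision_at_top_k():
+   import tensorflow as tf
+   labels = tf.constant([[0], [2]], dtype=tf.int64)   # one label per row
+   top_idx = tf.constant([[0], [1]], dtype=tf.int64)  # precomputed top-1 ids
+   prec, update_op = precision_at_top_k(
+       labels=labels, predictions_idx=top_idx, k=1)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())  # init tp/fp accumulators
+     sess.run(update_op)                         # accumulate one batch
+     return sess.run(prec)                       # 0.5: one of two hits
+
+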
+ @tf_export('metrics.sparse_precision_at_k')
+ @deprecated(None, 'Use precision_at_k instead')
+ def sparse_precision_at_k(labels,
+                           predictions,
+                           k,
+                           class_id=None,
+                           weights=None,
+                           metrics_collections=None,
+                           updates_collections=None,
+                           name=None):
+   """Renamed to `precision_at_k`, please use that method instead."""
+   return precision_at_k(
+       labels=labels,
+       predictions=predictions,
+       k=k,
+       class_id=class_id,
+       weights=weights,
+       metrics_collections=metrics_collections,
+       updates_collections=updates_collections,
+       name=name)
+
+
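+ # Migration sketch: `sparse_precision_at_k` is only a deprecated alias, so
+ # the two calls below (with made-up `labels` and `logits`) are equivalent:
+ #
+ #   prec, update_op = sparse_precision_at_k(labels, logits, k=5)
+ #   prec, update_op = precision_at_k(labels, logits, k=5)  # preferred
+
+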
+ @tf_export('metrics.precision_at_k')
+ def precision_at_k(labels,
+                    predictions,
+                    k,
+                    class_id=None,
+                    weights=None,
+                    metrics_collections=None,
+                    updates_collections=None,
+                    name=None):
+   """Computes precision@k of the predictions with respect to sparse labels.
+
+   If `class_id` is specified, we calculate precision by considering only the
+   entries in the batch for which `class_id` is in the top-k highest
+   `predictions`, and computing the fraction of them for which `class_id` is
+   indeed a correct label.
+   If `class_id` is not specified, we'll calculate precision as how often on
+   average a class among the top-k classes with the highest predicted values
+   of a batch entry is correct and can be found in the label for that entry.
+
+   `precision_at_k` creates two local variables,
+   `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
+   the precision@k frequency. This frequency is ultimately returned as
+   `precision_at_<k>`: an idempotent operation that simply divides
+   `true_positive_at_<k>` by total (`true_positive_at_<k>` +
+   `false_positive_at_<k>`).
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
+   indicating the top `k` `predictions`. Set operations applied to `top_k` and
+   `labels` calculate the true positives and false positives weighted by
+   `weights`. Then `update_op` increments `true_positive_at_<k>` and
+   `false_positive_at_<k>` using these values.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   Args:
+     labels: `int64` `Tensor` or `SparseTensor` with shape
+       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
+       num_labels=1. N >= 1 and num_labels is the number of target classes for
+       the associated prediction. Commonly, N=1 and `labels` has shape
+       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
+       should be in range [0, num_classes), where num_classes is the last
+       dimension of `predictions`. Values outside this range are ignored.
+     predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
+       N >= 1. Commonly, N=1 and `predictions` has shape
+       [batch_size, num_classes]. The final dimension contains the logit values
+       for each class. [D1, ... DN] must match `labels`.
+     k: Integer, k for @k metric.
+     class_id: Integer class ID for which we want binary metrics. This should be
+       in range [0, num_classes), where num_classes is the last dimension of
+       `predictions`. If `class_id` is outside this range, the method returns
+       NAN.
+     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
+       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
+       dimensions must be either `1`, or the same as the corresponding `labels`
+       dimension).
+     metrics_collections: An optional list of collections that values should
+       be added to.
+     updates_collections: An optional list of collections that updates should
+       be added to.
+     name: Name of new update operation, and namespace for other dependent ops.
+
+   Returns:
+     precision: Scalar `float64` `Tensor` with the value of `true_positives`
+       divided by the sum of `true_positives` and `false_positives`.
+     update_op: `Operation` that increments `true_positives` and
+       `false_positives` variables appropriately, and whose value matches
+       `precision`.
+
+   Raises:
+     ValueError: If `weights` is not `None` and its shape doesn't match
+       `predictions`, or if either `metrics_collections` or
+       `updates_collections` are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.precision_at_k is not '
+                        'supported when eager execution is enabled.')
+
+   with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
+                       (predictions, labels, weights)) as scope:
+     _, top_k_idx = nn.top_k(predictions, k)
+     return precision_at_top_k(
+         labels=labels,
+         predictions_idx=top_k_idx,
+         k=k,
+         class_id=class_id,
+         weights=weights,
+         metrics_collections=metrics_collections,
+         updates_collections=updates_collections,
+         name=scope)
+
+
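+ # Illustrative usage sketch (graph-mode TensorFlow 1.x assumed; the
+ # `_example_*` name and tensor values are made up): unlike
+ # `precision_at_top_k`, this variant takes raw logits and runs `top_k`
+ # internally.
+ def _example_precision_at_k():
+   import tensorflow as tf
+   labels = tf.constant([[0], [1]], dtype=tf.int64)  # shape [batch, 1]
+   logits = tf.constant([[0.9, 0.1],                 # top-1 is class 0
+                         [0.3, 0.7]])                # top-1 is class 1
+   prec, update_op = precision_at_k(labels=labels, predictions=logits, k=1)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(prec)  # 1.0: both top-1 predictions match the labels
+
+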
+ @tf_export('metrics.specificity_at_sensitivity')
+ def specificity_at_sensitivity(labels,
+                                predictions,
+                                sensitivity,
+                                weights=None,
+                                num_thresholds=200,
+                                metrics_collections=None,
+                                updates_collections=None,
+                                name=None):
+   """Computes the specificity at a given sensitivity.
+
+   The `specificity_at_sensitivity` function creates four local
+   variables, `true_positives`, `true_negatives`, `false_positives` and
+   `false_negatives` that are used to compute the specificity at the given
+   sensitivity value. The threshold for the given sensitivity value is computed
+   and used to evaluate the corresponding specificity.
+
+   For estimation of the metric over a stream of data, the function creates an
+   `update_op` operation that updates these variables and returns the
+   `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
+   `false_positives` and `false_negatives` counts with the weight of each case
+   found in the `predictions` and `labels`.
+
+   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+   For additional information about specificity and sensitivity, see the
+   following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+
+   Args:
+     labels: The ground truth values, a `Tensor` whose dimensions must match
+       `predictions`. Will be cast to `bool`.
+     predictions: A floating point `Tensor` of arbitrary shape and whose values
+       are in the range `[0, 1]`.
+     sensitivity: A scalar value in range `[0, 1]`.
+     weights: Optional `Tensor` whose rank is either 0, or the same rank as
+       `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+       must be either `1`, or the same as the corresponding `labels` dimension).
+     num_thresholds: The number of thresholds to use for matching the given
+       sensitivity.
+     metrics_collections: An optional list of collections that `specificity`
+       should be added to.
+     updates_collections: An optional list of collections that `update_op`
+       should be added to.
+     name: An optional variable_scope name.
+
+   Returns:
+     specificity: A scalar `Tensor` representing the specificity at the given
+       `sensitivity` value.
+     update_op: An operation that increments the `true_positives`,
+       `true_negatives`, `false_positives` and `false_negatives` variables
+       appropriately and whose value matches `specificity`.
+
+   Raises:
+     ValueError: If `predictions` and `labels` have mismatched shapes, if
+       `weights` is not `None` and its shape doesn't match `predictions`, or if
+       `sensitivity` is not between 0 and 1, or if either `metrics_collections`
+       or `updates_collections` are not a list or tuple.
+     RuntimeError: If eager execution is enabled.
+   """
+   if context.executing_eagerly():
+     raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
+                        'supported when eager execution is enabled.')
+
+   if sensitivity < 0 or sensitivity > 1:
+     raise ValueError('`sensitivity` must be in the range [0, 1].')
+
+   with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
+                                      (predictions, labels, weights)):
+     kepsilon = 1e-7  # to account for floating point imprecisions
+     thresholds = [
+         (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+     ]
+     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
+
+     values, update_ops = _confusion_matrix_at_thresholds(
+         labels, predictions, thresholds, weights)
+
+     def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
+       """Computes the specificity at the given sensitivity.
+
+       Args:
+         tp: True positives.
+         tn: True negatives.
+         fp: False positives.
+         fn: False negatives.
+         name: The name of the operation.
+
+       Returns:
+         The specificity using the aggregated values.
+       """
+       sensitivities = math_ops.div(tp, tp + fn + kepsilon)
+
+       # We'll need to use this trick until tf.argmax allows us to specify
+       # whether we should use the first or last index in case of ties.
+       min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
+       indices_at_minval = math_ops.equal(
+           math_ops.abs(sensitivities - sensitivity), min_val)
+       indices_at_minval = math_ops.to_int64(indices_at_minval)
+       indices_at_minval = math_ops.cumsum(indices_at_minval)
+       tf_index = math_ops.argmax(indices_at_minval, 0)
+       tf_index = math_ops.cast(tf_index, dtypes.int32)
+
+       # Now, we have the implicit threshold, so compute the specificity:
+       return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
+                           name)
+
+     def specificity_across_towers(_, values):
+       return compute_specificity_at_sensitivity(values['tp'], values['tn'],
+                                                 values['fp'], values['fn'],
+                                                 'value')
+
+     specificity = _aggregate_across_towers(metrics_collections,
+                                            specificity_across_towers, values)
+
+     update_op = compute_specificity_at_sensitivity(update_ops['tp'],
+                                                    update_ops['tn'],
+                                                    update_ops['fp'],
+                                                    update_ops['fn'],
+                                                    'update_op')
+     if updates_collections:
+       ops.add_to_collections(updates_collections, update_op)
+
+     return specificity, update_op
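+
+
+ # Illustrative usage sketch (graph-mode TensorFlow 1.x assumed; the
+ # `_example_*` name and values are made up): the metric picks the threshold
+ # whose sensitivity is closest to the target and reports specificity there.
+ def _example_specificity_at_sensitivity():
+   import tensorflow as tf
+   labels = tf.constant([0, 0, 1, 1])          # two negatives, two positives
+   preds = tf.constant([0.1, 0.6, 0.4, 0.9])   # predicted probabilities
+   spec, update_op = specificity_at_sensitivity(
+       labels=labels, predictions=preds, sensitivity=0.5)
+   with tf.Session() as sess:
+     sess.run(tf.local_variables_initializer())
+     sess.run(update_op)
+     return sess.run(spec)  # ~1.0: no negative scores above the threshold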