tensorbored 2.21.0rc1769983804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorbored/__init__.py +112 -0
- tensorbored/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/__init__.py +125 -0
- tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
- tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
- tensorbored/_vendor/bleach/callbacks.py +32 -0
- tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
- tensorbored/_vendor/bleach/linkifier.py +633 -0
- tensorbored/_vendor/bleach/parse_shim.py +1 -0
- tensorbored/_vendor/bleach/sanitizer.py +638 -0
- tensorbored/_vendor/bleach/six_shim.py +19 -0
- tensorbored/_vendor/webencodings/__init__.py +342 -0
- tensorbored/_vendor/webencodings/labels.py +231 -0
- tensorbored/_vendor/webencodings/mklabels.py +59 -0
- tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
- tensorbored/assets.py +36 -0
- tensorbored/auth.py +102 -0
- tensorbored/backend/__init__.py +0 -0
- tensorbored/backend/application.py +604 -0
- tensorbored/backend/auth_context_middleware.py +38 -0
- tensorbored/backend/client_feature_flags.py +113 -0
- tensorbored/backend/empty_path_redirect.py +46 -0
- tensorbored/backend/event_processing/__init__.py +0 -0
- tensorbored/backend/event_processing/data_ingester.py +276 -0
- tensorbored/backend/event_processing/data_provider.py +535 -0
- tensorbored/backend/event_processing/directory_loader.py +142 -0
- tensorbored/backend/event_processing/directory_watcher.py +272 -0
- tensorbored/backend/event_processing/event_accumulator.py +950 -0
- tensorbored/backend/event_processing/event_file_inspector.py +463 -0
- tensorbored/backend/event_processing/event_file_loader.py +292 -0
- tensorbored/backend/event_processing/event_multiplexer.py +521 -0
- tensorbored/backend/event_processing/event_util.py +68 -0
- tensorbored/backend/event_processing/io_wrapper.py +223 -0
- tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
- tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
- tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
- tensorbored/backend/event_processing/reservoir.py +266 -0
- tensorbored/backend/event_processing/tag_types.py +29 -0
- tensorbored/backend/experiment_id.py +71 -0
- tensorbored/backend/experimental_plugin.py +51 -0
- tensorbored/backend/http_util.py +263 -0
- tensorbored/backend/json_util.py +70 -0
- tensorbored/backend/path_prefix.py +67 -0
- tensorbored/backend/process_graph.py +74 -0
- tensorbored/backend/security_validator.py +202 -0
- tensorbored/compat/__init__.py +69 -0
- tensorbored/compat/proto/__init__.py +0 -0
- tensorbored/compat/proto/allocation_description_pb2.py +35 -0
- tensorbored/compat/proto/api_def_pb2.py +82 -0
- tensorbored/compat/proto/attr_value_pb2.py +80 -0
- tensorbored/compat/proto/cluster_pb2.py +58 -0
- tensorbored/compat/proto/config_pb2.py +271 -0
- tensorbored/compat/proto/coordination_config_pb2.py +45 -0
- tensorbored/compat/proto/cost_graph_pb2.py +87 -0
- tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
- tensorbored/compat/proto/debug_pb2.py +65 -0
- tensorbored/compat/proto/event_pb2.py +149 -0
- tensorbored/compat/proto/full_type_pb2.py +74 -0
- tensorbored/compat/proto/function_pb2.py +157 -0
- tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
- tensorbored/compat/proto/graph_pb2.py +41 -0
- tensorbored/compat/proto/histogram_pb2.py +39 -0
- tensorbored/compat/proto/meta_graph_pb2.py +254 -0
- tensorbored/compat/proto/node_def_pb2.py +61 -0
- tensorbored/compat/proto/op_def_pb2.py +81 -0
- tensorbored/compat/proto/resource_handle_pb2.py +48 -0
- tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
- tensorbored/compat/proto/rpc_options_pb2.py +35 -0
- tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
- tensorbored/compat/proto/saver_pb2.py +38 -0
- tensorbored/compat/proto/step_stats_pb2.py +116 -0
- tensorbored/compat/proto/struct_pb2.py +144 -0
- tensorbored/compat/proto/summary_pb2.py +111 -0
- tensorbored/compat/proto/tensor_description_pb2.py +38 -0
- tensorbored/compat/proto/tensor_pb2.py +68 -0
- tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
- tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
- tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
- tensorbored/compat/proto/types_pb2.py +105 -0
- tensorbored/compat/proto/variable_pb2.py +62 -0
- tensorbored/compat/proto/verifier_config_pb2.py +38 -0
- tensorbored/compat/proto/versions_pb2.py +35 -0
- tensorbored/compat/tensorflow_stub/__init__.py +38 -0
- tensorbored/compat/tensorflow_stub/app.py +124 -0
- tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
- tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
- tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
- tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
- tensorbored/compat/tensorflow_stub/errors.py +507 -0
- tensorbored/compat/tensorflow_stub/flags.py +124 -0
- tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
- tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
- tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
- tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
- tensorbored/context.py +129 -0
- tensorbored/data/__init__.py +0 -0
- tensorbored/data/grpc_provider.py +365 -0
- tensorbored/data/ingester.py +46 -0
- tensorbored/data/proto/__init__.py +0 -0
- tensorbored/data/proto/data_provider_pb2.py +517 -0
- tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
- tensorbored/data/provider.py +1365 -0
- tensorbored/data/server_ingester.py +301 -0
- tensorbored/data_compat.py +159 -0
- tensorbored/dataclass_compat.py +224 -0
- tensorbored/default.py +124 -0
- tensorbored/errors.py +130 -0
- tensorbored/lazy.py +99 -0
- tensorbored/main.py +48 -0
- tensorbored/main_lib.py +62 -0
- tensorbored/manager.py +487 -0
- tensorbored/notebook.py +441 -0
- tensorbored/plugin_util.py +266 -0
- tensorbored/plugins/__init__.py +0 -0
- tensorbored/plugins/audio/__init__.py +0 -0
- tensorbored/plugins/audio/audio_plugin.py +229 -0
- tensorbored/plugins/audio/metadata.py +69 -0
- tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
- tensorbored/plugins/audio/summary.py +230 -0
- tensorbored/plugins/audio/summary_v2.py +124 -0
- tensorbored/plugins/base_plugin.py +367 -0
- tensorbored/plugins/core/__init__.py +0 -0
- tensorbored/plugins/core/core_plugin.py +981 -0
- tensorbored/plugins/custom_scalar/__init__.py +0 -0
- tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
- tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
- tensorbored/plugins/custom_scalar/metadata.py +35 -0
- tensorbored/plugins/custom_scalar/summary.py +79 -0
- tensorbored/plugins/debugger_v2/__init__.py +0 -0
- tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
- tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
- tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
- tensorbored/plugins/distribution/__init__.py +0 -0
- tensorbored/plugins/distribution/compressor.py +158 -0
- tensorbored/plugins/distribution/distributions_plugin.py +116 -0
- tensorbored/plugins/distribution/metadata.py +19 -0
- tensorbored/plugins/graph/__init__.py +0 -0
- tensorbored/plugins/graph/graph_util.py +129 -0
- tensorbored/plugins/graph/graphs_plugin.py +336 -0
- tensorbored/plugins/graph/keras_util.py +328 -0
- tensorbored/plugins/graph/metadata.py +42 -0
- tensorbored/plugins/histogram/__init__.py +0 -0
- tensorbored/plugins/histogram/histograms_plugin.py +144 -0
- tensorbored/plugins/histogram/metadata.py +63 -0
- tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
- tensorbored/plugins/histogram/summary.py +234 -0
- tensorbored/plugins/histogram/summary_v2.py +292 -0
- tensorbored/plugins/hparams/__init__.py +14 -0
- tensorbored/plugins/hparams/_keras.py +93 -0
- tensorbored/plugins/hparams/api.py +130 -0
- tensorbored/plugins/hparams/api_pb2.py +208 -0
- tensorbored/plugins/hparams/backend_context.py +606 -0
- tensorbored/plugins/hparams/download_data.py +158 -0
- tensorbored/plugins/hparams/error.py +26 -0
- tensorbored/plugins/hparams/get_experiment.py +71 -0
- tensorbored/plugins/hparams/hparams_plugin.py +206 -0
- tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
- tensorbored/plugins/hparams/json_format_compat.py +38 -0
- tensorbored/plugins/hparams/list_metric_evals.py +57 -0
- tensorbored/plugins/hparams/list_session_groups.py +1040 -0
- tensorbored/plugins/hparams/metadata.py +125 -0
- tensorbored/plugins/hparams/metrics.py +41 -0
- tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
- tensorbored/plugins/hparams/summary.py +205 -0
- tensorbored/plugins/hparams/summary_v2.py +597 -0
- tensorbored/plugins/image/__init__.py +0 -0
- tensorbored/plugins/image/images_plugin.py +232 -0
- tensorbored/plugins/image/metadata.py +65 -0
- tensorbored/plugins/image/plugin_data_pb2.py +34 -0
- tensorbored/plugins/image/summary.py +159 -0
- tensorbored/plugins/image/summary_v2.py +130 -0
- tensorbored/plugins/mesh/__init__.py +14 -0
- tensorbored/plugins/mesh/mesh_plugin.py +292 -0
- tensorbored/plugins/mesh/metadata.py +152 -0
- tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
- tensorbored/plugins/mesh/summary.py +251 -0
- tensorbored/plugins/mesh/summary_v2.py +214 -0
- tensorbored/plugins/metrics/__init__.py +0 -0
- tensorbored/plugins/metrics/metadata.py +17 -0
- tensorbored/plugins/metrics/metrics_plugin.py +623 -0
- tensorbored/plugins/pr_curve/__init__.py +0 -0
- tensorbored/plugins/pr_curve/metadata.py +75 -0
- tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
- tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
- tensorbored/plugins/pr_curve/summary.py +574 -0
- tensorbored/plugins/profile_redirect/__init__.py +0 -0
- tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
- tensorbored/plugins/projector/__init__.py +67 -0
- tensorbored/plugins/projector/metadata.py +26 -0
- tensorbored/plugins/projector/projector_config_pb2.py +54 -0
- tensorbored/plugins/projector/projector_plugin.py +795 -0
- tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
- tensorbored/plugins/scalar/__init__.py +0 -0
- tensorbored/plugins/scalar/metadata.py +60 -0
- tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
- tensorbored/plugins/scalar/scalars_plugin.py +181 -0
- tensorbored/plugins/scalar/summary.py +109 -0
- tensorbored/plugins/scalar/summary_v2.py +124 -0
- tensorbored/plugins/text/__init__.py +0 -0
- tensorbored/plugins/text/metadata.py +62 -0
- tensorbored/plugins/text/plugin_data_pb2.py +34 -0
- tensorbored/plugins/text/summary.py +114 -0
- tensorbored/plugins/text/summary_v2.py +124 -0
- tensorbored/plugins/text/text_plugin.py +288 -0
- tensorbored/plugins/wit_redirect/__init__.py +0 -0
- tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
- tensorbored/program.py +910 -0
- tensorbored/summary/__init__.py +35 -0
- tensorbored/summary/_output.py +124 -0
- tensorbored/summary/_tf/__init__.py +14 -0
- tensorbored/summary/_tf/summary/__init__.py +178 -0
- tensorbored/summary/_writer.py +105 -0
- tensorbored/summary/v1.py +51 -0
- tensorbored/summary/v2.py +25 -0
- tensorbored/summary/writer/__init__.py +13 -0
- tensorbored/summary/writer/event_file_writer.py +291 -0
- tensorbored/summary/writer/record_writer.py +50 -0
- tensorbored/util/__init__.py +0 -0
- tensorbored/util/encoder.py +116 -0
- tensorbored/util/grpc_util.py +311 -0
- tensorbored/util/img_mime_type_detector.py +40 -0
- tensorbored/util/io_util.py +20 -0
- tensorbored/util/lazy_tensor_creator.py +110 -0
- tensorbored/util/op_evaluator.py +104 -0
- tensorbored/util/platform_util.py +20 -0
- tensorbored/util/tb_logging.py +24 -0
- tensorbored/util/tensor_util.py +617 -0
- tensorbored/util/timing.py +122 -0
- tensorbored/version.py +21 -0
- tensorbored/webfiles.zip +0 -0
- tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
- tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
- tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
- tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
- tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
- tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,574 @@
|
|
|
1
|
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Precision--recall curves and TensorFlow operations to create them.
|
|
16
|
+
|
|
17
|
+
NOTE: This module is in beta, and its API is subject to change, but the
|
|
18
|
+
data that it stores to disk will be supported forever.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from tensorbored.plugins.pr_curve import metadata
|
|
24
|
+
|
|
25
|
+
# A value that we use as the minimum value during division of counts to prevent
|
|
26
|
+
# division by 0. 1.0 does not work: Certain weights could cause counts below 1.
|
|
27
|
+
_MINIMUM_COUNT = 1e-7
|
|
28
|
+
|
|
29
|
+
# The default number of thresholds, used when a caller passes `num_thresholds=None`.
|
|
30
|
+
_DEFAULT_NUM_THRESHOLDS = 201
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create a PR curve summary op for a single binary classifier.

    Computes true/false positive/negative values for the given `predictions`
    against the ground truth `labels`, against a list of evenly distributed
    threshold values in `[0, 1]` of length `num_thresholds`.

    Each number in `predictions`, a float in `[0, 1]`, is compared with its
    corresponding boolean label in `labels`, and counts as a single tp/fp/tn/fn
    value at each threshold. This is then multiplied with `weights` which can be
    used to reweight certain values, or more commonly used for masking values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values. A Tensor of `bool` values with arbitrary
          shape.
      predictions: A float `Tensor` (float32 in the common case) whose values
          are in the range `[0, 1]`; values outside that range are clamped
          into it. Dimensions must match those of `labels`. The labels and
          the bucket counts are computed in this tensor's dtype.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
          compute PR metrics for. Should be `>= 2`. This value should be a
          constant integer value, not a Tensor that stores an integer.
          Defaults to `_DEFAULT_NUM_THRESHOLDS` when `None`.
      weights: Optional float `Tensor`. Individual counts are multiplied by this
          value. This tensor must be either the same shape as or broadcastable
          to the `labels` tensor. Defaults to `1.0` when `None`.
      display_name: Optional name for this summary in TensorBoard, as a
          constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
          constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new
          summary op is added to these collections. Defaults to
          `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. The float32 tensor
      produced by the summary operation is of dimension (6, num_thresholds). The
      first dimension (of length 6) is of the order: true positives,
      false positives, true negatives, false negatives, precision, recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    dtype = predictions.dtype

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tf.assert_type(labels, tf.bool)
        # We cast to float to ensure we have 0.0 or 1.0.
        f_labels = tf.cast(labels, dtype)
        # Ensure predictions are all in range [0.0, 1.0].
        predictions = tf.minimum(1.0, tf.maximum(0.0, predictions))
        # Get weighted true/false labels.
        true_labels = f_labels * weights
        false_labels = (1.0 - f_labels) * weights

        # Before we begin, flatten predictions.
        predictions = tf.reshape(predictions, [-1])

        # Shape the labels so they are broadcast-able for later multiplication.
        true_labels = tf.reshape(true_labels, [-1, 1])
        false_labels = tf.reshape(false_labels, [-1, 1])

        # To compute TP/FP/TN/FN, we are measuring a binary classifier
        #   C(t) = (predictions >= t)
        # at each threshold 't'. So we have
        #   TP(t) = sum( C(t) * true_labels )
        #   FP(t) = sum( C(t) * false_labels )
        #
        # But, computing C(t) requires computation for each t. To make it fast,
        # observe that C(t) is a cumulative integral, and so if we have
        #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
        # where n = num_thresholds, and if we can compute the bucket function
        #   B(i) = Sum( (predictions == t), t_i <= t < t{i+1} )
        # then we get
        #   C(t_i) = sum( B(j), j >= i )
        # which is the reversed cumulative sum in tf.cumsum().
        #
        # We can compute B(i) efficiently by taking advantage of the fact that
        # our thresholds are evenly distributed, in that
        #   width = 1.0 / (num_thresholds - 1)
        #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
        # Given a prediction value p, we can map it to its bucket by
        #   bucket_index(p) = floor( p * (num_thresholds - 1) )
        # so we can fill the buckets in one pass: one-hot encode every bucket
        # index and sum the label-weighted encodings over all predictions.

        # Compute the bucket indices for each prediction value.
        bucket_indices = tf.cast(
            tf.floor(predictions * (num_thresholds - 1)), tf.int32
        )

        # One row per prediction, one column per bucket. The encoding is
        # created in the labels' dtype: tf.one_hot defaults to float32, which
        # cannot be multiplied with labels of any other float type. It is
        # built once and shared by the TP and FP sums.
        bucket_one_hot = tf.one_hot(
            bucket_indices, depth=num_thresholds, dtype=dtype
        )

        # Bucket predictions.
        tp_buckets = tf.reduce_sum(
            input_tensor=bucket_one_hot * true_labels,
            axis=0,
        )
        fp_buckets = tf.reduce_sum(
            input_tensor=bucket_one_hot * false_labels,
            axis=0,
        )

        # Set up the cumulative sums to compute the actual metrics.
        tp = tf.cumsum(tp_buckets, reverse=True, name="tp")
        fp = tf.cumsum(fp_buckets, reverse=True, name="fp")
        # fn = sum(true_labels) - tp
        #    = sum(tp_buckets) - tp
        #    = tp[0] - tp
        # Similarly,
        # tn = fp[0] - fp
        tn = fp[0] - fp
        fn = tp[0] - tp

        # _MINIMUM_COUNT keeps the denominators away from zero.
        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        return _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def pb(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf.

    Arguments:
      name: A name for the generated node. Will also serve as a series name in
          TensorBoard.
      labels: The ground truth values. A bool numpy array (or array-like).
      predictions: A float32 numpy array (or array-like) whose values are in
          the range `[0, 1]`. Values outside that range are clipped into it,
          as `op` does, instead of being left out of the counts. Dimensions
          must match those of `labels`.
      num_thresholds: Optional number of thresholds, evenly distributed in
          `[0, 1]`, to compute PR metrics for. When provided, should be an int of
          value at least 2. Defaults to 201.
      weights: Optional float or float32 numpy array. Individual counts are
          multiplied by this value. This tensor must be either the same shape as
          or broadcastable to the `labels` numpy array.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
          Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
          Markdown is supported. Defaults to empty.

    Returns:
      The PR curves summary protobuf built by `raw_data_pb` from the computed
      counts, precision and recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf  # noqa: F401

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    # Coerce array-likes up front: a plain list would otherwise be repeated
    # by `list * int` below rather than scaled element-wise.
    labels = np.asarray(labels)
    # Clamp to [0, 1]. An out-of-range prediction maps to a bucket index
    # outside `histogram_range`, which np.histogram silently ignores.
    predictions = np.clip(np.asarray(predictions), 0.0, 1.0)

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(float)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    # Index 0 holds the totals over all buckets, so subtracting the running
    # counts from it yields the complementary (negative) counts.
    tn = fp[0] - fp
    fn = tp[0] - tp
    # _MINIMUM_COUNT keeps the denominators away from zero.
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

    return raw_data_pb(
        name,
        true_positive_counts=tp,
        false_positive_counts=fp,
        true_negative_counts=tn,
        false_negative_counts=fn,
        precision=precision,
        recall=recall,
        num_thresholds=num_thresholds,
        display_name=display_name,
        description=description,
    )
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def streaming_op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    metrics_collections=None,
    updates_collections=None,
    display_name=None,
    description=None,
):
    """Build a PR curve summary that accumulates over batches of data.

    Like `op`, but the counts come from the streaming
    `tf.metrics.*_at_thresholds` metrics, so the curve can cover many batches
    of labels and predictions, in the same style as the metrics found in
    tf.metrics.

    Local variables holding the true positives, true negatives, etc.
    accumulated over each batch back the final PR curve summary. Running the
    returned `update_op` updates those variables.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values, a `Tensor` whose dimensions must match
          `predictions`. Will be cast to `bool`.
      predictions: A floating point `Tensor` of arbitrary shape and whose values
          are in the range `[0, 1]`.
      num_thresholds: The number of evenly spaced thresholds to generate for
          computing the PR curve. Defaults to 201.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
          `labels`, and must be broadcastable to `labels` (i.e., all dimensions
          must be either `1`, or the same as the corresponding `labels`
          dimension).
      metrics_collections: An optional list of collections that the `pr_curve`
          summary should be added to.
      updates_collections: An optional list of collections that `update_op`
          should be added to.
      display_name: Optional name for this summary in TensorBoard, as a
          constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
          constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      pr_curve: A string `Tensor` containing a single value: the
          serialized PR curve Tensor summary. The summary contains a
          float32 `Tensor` of dimension (6, num_thresholds). The first
          dimension (of length 6) is of the order: true positives, false
          positives, true negatives, false negatives, precision, recall.
      update_op: An operation that updates the summary with the latest data.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    # Evenly spaced cut-offs from 0.0 to 1.0 inclusive.
    thresholds = [
        step / float(num_thresholds - 1) for step in range(num_thresholds)
    ]

    with tf.name_scope(name, values=[labels, predictions, weights]):
        # All four streaming metrics take exactly the same inputs.
        metric_inputs = dict(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        tp, update_tp = tf.metrics.true_positives_at_thresholds(
            **metric_inputs
        )
        fp, update_fp = tf.metrics.false_positives_at_thresholds(
            **metric_inputs
        )
        tn, update_tn = tf.metrics.true_negatives_at_thresholds(
            **metric_inputs
        )
        fn, update_fn = tf.metrics.false_negatives_at_thresholds(
            **metric_inputs
        )

        # _MINIMUM_COUNT keeps the denominators away from zero.
        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        # The metrics collections are handed to the summary helper as the
        # collections for the summary op.
        pr_curve = _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            metrics_collections,
        )

        # A single op that runs all four metric updates together.
        update_op = tf.group(update_tp, update_fp, update_tn, update_fn)
        for collection_key in updates_collections or ():
            tf.add_to_collection(collection_key, update_op)

        return pr_curve, update_op
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def raw_data_op(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Build a PR-curve summary op from statistics the caller already computed.

    Nothing is derived here: the counts, precision and recall tensors are
    forwarded untouched to `_create_tensor_summary` inside a name scope
    called `name`. The caller is responsible for making sure the counts are
    consistent with the supplied precision and recall values.

    Use this when precision and recall are computed in a custom way but the
    result should still be rendered by the PR curves plugin.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for
        organization.
      true_positive_counts: A rank-1 tensor of true positive counts. Must
        contain `num_thresholds` elements and be castable to float32. Values
        correspond to thresholds that increase from left to right (from 0
        to 1).
      false_positive_counts: A rank-1 tensor of false positive counts. Same
        length, dtype and ordering requirements as `true_positive_counts`.
      true_negative_counts: A rank-1 tensor of true negative counts. Same
        length, dtype and ordering requirements as `true_positive_counts`.
      false_negative_counts: A rank-1 tensor of false negative counts. Same
        length, dtype and ordering requirements as `true_positive_counts`.
      precision: A rank-1 tensor of precision values. Same length, dtype and
        ordering requirements as `true_positive_counts`.
      recall: A rank-1 tensor of recall values. Same length, dtype and
        ordering requirements as `true_positive_counts`.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`,
        to compute PR metrics for. Should be `>= 2`. This value should be a
        constant integer value, not a Tensor that stores an integer.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new
        summary op is added to these collections. Defaults to
        `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. See docs for the
      `op` method for details on the float32 tensor produced by this summary.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    # Order matters: it is the positional order `_create_tensor_summary`
    # expects right after `name`.
    curve_inputs = [
        true_positive_counts,
        false_positive_counts,
        true_negative_counts,
        false_negative_counts,
        precision,
        recall,
    ]
    with tf.name_scope(name, values=curve_inputs):
        return _create_tensor_summary(
            name,
            *curve_inputs,
            num_thresholds=num_thresholds,
            display_name=display_name,
            description=description,
            collections=collections,
        )
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def raw_data_pb(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf from raw data values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for
        organization.
      true_positive_counts: A rank-1 numpy array of true positive counts.
        Must contain `num_thresholds` elements and be castable to float32.
      false_positive_counts: A rank-1 numpy array of false positive counts.
        Must contain `num_thresholds` elements and be castable to float32.
      true_negative_counts: A rank-1 numpy array of true negative counts.
        Must contain `num_thresholds` elements and be castable to float32.
      false_negative_counts: A rank-1 numpy array of false negative counts.
        Must contain `num_thresholds` elements and be castable to float32.
      precision: A rank-1 numpy array of precision values. Must contain
        `num_thresholds` elements and be castable to float32.
      recall: A rank-1 numpy array of recall values. Must contain
        `num_thresholds` elements and be castable to float32.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`,
        to compute PR metrics for. Should be an int `>= 2`.
      display_name: Optional name for this summary in TensorBoard, as a
        `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        `str`. Markdown is supported. Defaults to empty.

    Returns:
      A `tf.Summary` protobuf containing a single value tagged
      `<name>/pr_curves`. Its float32 tensor stacks the six inputs along a
      new leading axis in the order: true positives, false positives, true
      negatives, false negatives, precision, recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if display_name is None:
        display_name = name
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name,
        description=description or "",
        num_thresholds=num_thresholds,
    )
    # Round-trip through the wire format to obtain TensorFlow's own
    # SummaryMetadata class (presumably the helper returns this package's
    # proto class, which `tf.Summary` would not accept directly -- confirm).
    tf_summary_metadata = tf.SummaryMetadata.FromString(
        summary_metadata.SerializeToString()
    )
    summary = tf.Summary()
    data = np.stack(
        (
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        )
    )
    tensor = tf.make_tensor_proto(np.float32(data), dtype=tf.float32)
    summary.value.add(
        tag="%s/pr_curves" % name, metadata=tf_summary_metadata, tensor=tensor
    )
    return summary
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def _create_tensor_summary(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Assemble the tensor summary shared by the PR-curve op builders.

    This lives in its own helper, rather than having `op` call
    `raw_data_op`, so that the name scope opened by `raw_data_op` does not
    end up nested inside the scope of `op`.

    Arguments are the same as for raw_data_op.

    Returns:
      A tensor summary that collects data for PR curves.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if display_name is None:
        display_name = name

    # The threshold count goes into the summary metadata because it is
    # constant across every PR-curve summary that shares a tag.
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name,
        description=description or "",
        num_thresholds=num_thresholds,
    )

    # Rows of the stacked tensor, in order: true positives, false positives,
    # true negatives, false negatives, precision, recall.
    rows = (
        true_positive_counts,
        false_positive_counts,
        true_negative_counts,
        false_negative_counts,
        precision,
        recall,
    )
    combined_data = tf.stack([tf.cast(row, tf.float32) for row in rows])

    return tf.summary.tensor_summary(
        name="pr_curves",
        tensor=combined_data,
        collections=collections,
        summary_metadata=summary_metadata,
    )
|
|
File without changes
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Plugin that only displays a message with installation instructions."""
|
|
16
|
+
|
|
17
|
+
from tensorbored.plugins import base_plugin
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ProfileRedirectPluginLoader(base_plugin.TBLoader):
    """Load the redirect notice iff the dynamic plugin is unavailable."""

    def load(self, context):
        """Return the redirect plugin only when the real one cannot be imported.

        Args:
          context: Passed through unchanged to `_ProfileRedirectPlugin`.

        Returns:
          `None` when `tensorboard_plugin_profile` imports successfully, so
          the redirect notice is not shown at all; otherwise a
          `_ProfileRedirectPlugin` built from `context`.
        """
        try:
            import tensorboard_plugin_profile  # noqa: F401
        except ImportError:
            return _ProfileRedirectPlugin(context)
        # The dynamic plugin is importable, so no redirect notice is needed.
        return None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class _ProfileRedirectPlugin(base_plugin.TBPlugin):
    """Redirect notice pointing users to the new dynamic profile plugin."""

    plugin_name = "profile_redirect"

    def get_plugin_apps(self):
        """Return an empty mapping: this plugin registers no apps."""
        no_apps = {}
        return no_apps

    def is_active(self):
        """Always report the plugin as inactive."""
        return False

    def frontend_metadata(self):
        """Describe the frontend element and tab name used for the notice."""
        redirect_metadata = base_plugin.FrontendMetadata(
            element_name="tf-profile-redirect-dashboard",
            tab_name="Profile",
        )
        return redirect_metadata
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Public API for the Embedding Projector.
|
|
16
|
+
|
|
17
|
+
@@ProjectorPluginAsset
|
|
18
|
+
@@ProjectorConfig
|
|
19
|
+
@@EmbeddingInfo
|
|
20
|
+
@@EmbeddingMetadata
|
|
21
|
+
@@SpriteMetadata
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import os
|
|
25
|
+
|
|
26
|
+
from google.protobuf import text_format as _text_format
|
|
27
|
+
from tensorbored.compat import tf
|
|
28
|
+
from tensorbored.plugins.projector import metadata as _metadata
|
|
29
|
+
from tensorbored.plugins.projector.projector_config_pb2 import ( # noqa: F401
|
|
30
|
+
EmbeddingInfo,
|
|
31
|
+
)
|
|
32
|
+
from tensorbored.plugins.projector.projector_config_pb2 import ( # noqa: F401
|
|
33
|
+
SpriteMetadata,
|
|
34
|
+
)
|
|
35
|
+
from tensorbored.plugins.projector.projector_config_pb2 import ( # noqa: F401
|
|
36
|
+
ProjectorConfig,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def visualize_embeddings(logdir, config):
    """Write the embedding projector's config file into a log directory.

    The config proto is rendered in protobuf text format and written to
    `_metadata.PROJECTOR_FILENAME` inside `logdir`.

    Args:
      logdir: Directory into which to store the config file, as a `str`.
        For compatibility, can also be a `tf.compat.v1.summary.FileWriter`
        object open at the desired logdir; any object exposing a
        `get_logdir()` method is resolved by calling it.
      config: `ProjectorConfig` proto that holds the configuration for the
        projector, such as paths to checkpoint files and metadata files for
        the embeddings.
        NOTE(review): earlier documentation said an unset
        `config.model_checkpoint_path` defaults to the writer's logdir;
        nothing in this function sets it -- confirm whether that still holds.

    Raises:
      ValueError: If the resolved log directory is `None`, e.g. because the
        summary writer does not have a `logdir`.
    """
    # Convert from `tf.compat.v1.summary.FileWriter` if necessary.
    if hasattr(logdir, "get_logdir"):
        logdir = logdir.get_logdir()

    # Sanity checks.
    if logdir is None:
        raise ValueError("Expected logdir to be a path, but got None")

    # Saving the config file in the logdir.
    config_path = os.path.join(logdir, _metadata.PROJECTOR_FILENAME)
    serialized_config = _text_format.MessageToString(config)
    with tf.io.gfile.GFile(config_path, "w") as config_file:
        config_file.write(serialized_config)
|