tensorbored 2.21.0rc1769983804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (271)
  1. tensorbored/__init__.py +112 -0
  2. tensorbored/_vendor/__init__.py +0 -0
  3. tensorbored/_vendor/bleach/__init__.py +125 -0
  4. tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
  5. tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
  6. tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
  7. tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
  8. tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
  9. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
  10. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
  11. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
  12. tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
  13. tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
  14. tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
  15. tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
  16. tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
  17. tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
  18. tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
  19. tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
  20. tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
  21. tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
  22. tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
  23. tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
  24. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
  25. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
  26. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
  27. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
  28. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
  29. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
  30. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
  31. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
  32. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
  33. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
  34. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
  35. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
  36. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
  37. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
  38. tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
  39. tensorbored/_vendor/bleach/callbacks.py +32 -0
  40. tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
  41. tensorbored/_vendor/bleach/linkifier.py +633 -0
  42. tensorbored/_vendor/bleach/parse_shim.py +1 -0
  43. tensorbored/_vendor/bleach/sanitizer.py +638 -0
  44. tensorbored/_vendor/bleach/six_shim.py +19 -0
  45. tensorbored/_vendor/webencodings/__init__.py +342 -0
  46. tensorbored/_vendor/webencodings/labels.py +231 -0
  47. tensorbored/_vendor/webencodings/mklabels.py +59 -0
  48. tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
  49. tensorbored/assets.py +36 -0
  50. tensorbored/auth.py +102 -0
  51. tensorbored/backend/__init__.py +0 -0
  52. tensorbored/backend/application.py +604 -0
  53. tensorbored/backend/auth_context_middleware.py +38 -0
  54. tensorbored/backend/client_feature_flags.py +113 -0
  55. tensorbored/backend/empty_path_redirect.py +46 -0
  56. tensorbored/backend/event_processing/__init__.py +0 -0
  57. tensorbored/backend/event_processing/data_ingester.py +276 -0
  58. tensorbored/backend/event_processing/data_provider.py +535 -0
  59. tensorbored/backend/event_processing/directory_loader.py +142 -0
  60. tensorbored/backend/event_processing/directory_watcher.py +272 -0
  61. tensorbored/backend/event_processing/event_accumulator.py +950 -0
  62. tensorbored/backend/event_processing/event_file_inspector.py +463 -0
  63. tensorbored/backend/event_processing/event_file_loader.py +292 -0
  64. tensorbored/backend/event_processing/event_multiplexer.py +521 -0
  65. tensorbored/backend/event_processing/event_util.py +68 -0
  66. tensorbored/backend/event_processing/io_wrapper.py +223 -0
  67. tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
  68. tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
  69. tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
  70. tensorbored/backend/event_processing/reservoir.py +266 -0
  71. tensorbored/backend/event_processing/tag_types.py +29 -0
  72. tensorbored/backend/experiment_id.py +71 -0
  73. tensorbored/backend/experimental_plugin.py +51 -0
  74. tensorbored/backend/http_util.py +263 -0
  75. tensorbored/backend/json_util.py +70 -0
  76. tensorbored/backend/path_prefix.py +67 -0
  77. tensorbored/backend/process_graph.py +74 -0
  78. tensorbored/backend/security_validator.py +202 -0
  79. tensorbored/compat/__init__.py +69 -0
  80. tensorbored/compat/proto/__init__.py +0 -0
  81. tensorbored/compat/proto/allocation_description_pb2.py +35 -0
  82. tensorbored/compat/proto/api_def_pb2.py +82 -0
  83. tensorbored/compat/proto/attr_value_pb2.py +80 -0
  84. tensorbored/compat/proto/cluster_pb2.py +58 -0
  85. tensorbored/compat/proto/config_pb2.py +271 -0
  86. tensorbored/compat/proto/coordination_config_pb2.py +45 -0
  87. tensorbored/compat/proto/cost_graph_pb2.py +87 -0
  88. tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
  89. tensorbored/compat/proto/debug_pb2.py +65 -0
  90. tensorbored/compat/proto/event_pb2.py +149 -0
  91. tensorbored/compat/proto/full_type_pb2.py +74 -0
  92. tensorbored/compat/proto/function_pb2.py +157 -0
  93. tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
  94. tensorbored/compat/proto/graph_pb2.py +41 -0
  95. tensorbored/compat/proto/histogram_pb2.py +39 -0
  96. tensorbored/compat/proto/meta_graph_pb2.py +254 -0
  97. tensorbored/compat/proto/node_def_pb2.py +61 -0
  98. tensorbored/compat/proto/op_def_pb2.py +81 -0
  99. tensorbored/compat/proto/resource_handle_pb2.py +48 -0
  100. tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
  101. tensorbored/compat/proto/rpc_options_pb2.py +35 -0
  102. tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
  103. tensorbored/compat/proto/saver_pb2.py +38 -0
  104. tensorbored/compat/proto/step_stats_pb2.py +116 -0
  105. tensorbored/compat/proto/struct_pb2.py +144 -0
  106. tensorbored/compat/proto/summary_pb2.py +111 -0
  107. tensorbored/compat/proto/tensor_description_pb2.py +38 -0
  108. tensorbored/compat/proto/tensor_pb2.py +68 -0
  109. tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
  110. tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
  111. tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
  112. tensorbored/compat/proto/types_pb2.py +105 -0
  113. tensorbored/compat/proto/variable_pb2.py +62 -0
  114. tensorbored/compat/proto/verifier_config_pb2.py +38 -0
  115. tensorbored/compat/proto/versions_pb2.py +35 -0
  116. tensorbored/compat/tensorflow_stub/__init__.py +38 -0
  117. tensorbored/compat/tensorflow_stub/app.py +124 -0
  118. tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
  119. tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
  120. tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
  121. tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
  122. tensorbored/compat/tensorflow_stub/errors.py +507 -0
  123. tensorbored/compat/tensorflow_stub/flags.py +124 -0
  124. tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
  125. tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
  126. tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
  127. tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
  128. tensorbored/context.py +129 -0
  129. tensorbored/data/__init__.py +0 -0
  130. tensorbored/data/grpc_provider.py +365 -0
  131. tensorbored/data/ingester.py +46 -0
  132. tensorbored/data/proto/__init__.py +0 -0
  133. tensorbored/data/proto/data_provider_pb2.py +517 -0
  134. tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
  135. tensorbored/data/provider.py +1365 -0
  136. tensorbored/data/server_ingester.py +301 -0
  137. tensorbored/data_compat.py +159 -0
  138. tensorbored/dataclass_compat.py +224 -0
  139. tensorbored/default.py +124 -0
  140. tensorbored/errors.py +130 -0
  141. tensorbored/lazy.py +99 -0
  142. tensorbored/main.py +48 -0
  143. tensorbored/main_lib.py +62 -0
  144. tensorbored/manager.py +487 -0
  145. tensorbored/notebook.py +441 -0
  146. tensorbored/plugin_util.py +266 -0
  147. tensorbored/plugins/__init__.py +0 -0
  148. tensorbored/plugins/audio/__init__.py +0 -0
  149. tensorbored/plugins/audio/audio_plugin.py +229 -0
  150. tensorbored/plugins/audio/metadata.py +69 -0
  151. tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
  152. tensorbored/plugins/audio/summary.py +230 -0
  153. tensorbored/plugins/audio/summary_v2.py +124 -0
  154. tensorbored/plugins/base_plugin.py +367 -0
  155. tensorbored/plugins/core/__init__.py +0 -0
  156. tensorbored/plugins/core/core_plugin.py +981 -0
  157. tensorbored/plugins/custom_scalar/__init__.py +0 -0
  158. tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
  159. tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
  160. tensorbored/plugins/custom_scalar/metadata.py +35 -0
  161. tensorbored/plugins/custom_scalar/summary.py +79 -0
  162. tensorbored/plugins/debugger_v2/__init__.py +0 -0
  163. tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
  164. tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
  165. tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
  166. tensorbored/plugins/distribution/__init__.py +0 -0
  167. tensorbored/plugins/distribution/compressor.py +158 -0
  168. tensorbored/plugins/distribution/distributions_plugin.py +116 -0
  169. tensorbored/plugins/distribution/metadata.py +19 -0
  170. tensorbored/plugins/graph/__init__.py +0 -0
  171. tensorbored/plugins/graph/graph_util.py +129 -0
  172. tensorbored/plugins/graph/graphs_plugin.py +336 -0
  173. tensorbored/plugins/graph/keras_util.py +328 -0
  174. tensorbored/plugins/graph/metadata.py +42 -0
  175. tensorbored/plugins/histogram/__init__.py +0 -0
  176. tensorbored/plugins/histogram/histograms_plugin.py +144 -0
  177. tensorbored/plugins/histogram/metadata.py +63 -0
  178. tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
  179. tensorbored/plugins/histogram/summary.py +234 -0
  180. tensorbored/plugins/histogram/summary_v2.py +292 -0
  181. tensorbored/plugins/hparams/__init__.py +14 -0
  182. tensorbored/plugins/hparams/_keras.py +93 -0
  183. tensorbored/plugins/hparams/api.py +130 -0
  184. tensorbored/plugins/hparams/api_pb2.py +208 -0
  185. tensorbored/plugins/hparams/backend_context.py +606 -0
  186. tensorbored/plugins/hparams/download_data.py +158 -0
  187. tensorbored/plugins/hparams/error.py +26 -0
  188. tensorbored/plugins/hparams/get_experiment.py +71 -0
  189. tensorbored/plugins/hparams/hparams_plugin.py +206 -0
  190. tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
  191. tensorbored/plugins/hparams/json_format_compat.py +38 -0
  192. tensorbored/plugins/hparams/list_metric_evals.py +57 -0
  193. tensorbored/plugins/hparams/list_session_groups.py +1040 -0
  194. tensorbored/plugins/hparams/metadata.py +125 -0
  195. tensorbored/plugins/hparams/metrics.py +41 -0
  196. tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
  197. tensorbored/plugins/hparams/summary.py +205 -0
  198. tensorbored/plugins/hparams/summary_v2.py +597 -0
  199. tensorbored/plugins/image/__init__.py +0 -0
  200. tensorbored/plugins/image/images_plugin.py +232 -0
  201. tensorbored/plugins/image/metadata.py +65 -0
  202. tensorbored/plugins/image/plugin_data_pb2.py +34 -0
  203. tensorbored/plugins/image/summary.py +159 -0
  204. tensorbored/plugins/image/summary_v2.py +130 -0
  205. tensorbored/plugins/mesh/__init__.py +14 -0
  206. tensorbored/plugins/mesh/mesh_plugin.py +292 -0
  207. tensorbored/plugins/mesh/metadata.py +152 -0
  208. tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
  209. tensorbored/plugins/mesh/summary.py +251 -0
  210. tensorbored/plugins/mesh/summary_v2.py +214 -0
  211. tensorbored/plugins/metrics/__init__.py +0 -0
  212. tensorbored/plugins/metrics/metadata.py +17 -0
  213. tensorbored/plugins/metrics/metrics_plugin.py +623 -0
  214. tensorbored/plugins/pr_curve/__init__.py +0 -0
  215. tensorbored/plugins/pr_curve/metadata.py +75 -0
  216. tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
  217. tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
  218. tensorbored/plugins/pr_curve/summary.py +574 -0
  219. tensorbored/plugins/profile_redirect/__init__.py +0 -0
  220. tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
  221. tensorbored/plugins/projector/__init__.py +67 -0
  222. tensorbored/plugins/projector/metadata.py +26 -0
  223. tensorbored/plugins/projector/projector_config_pb2.py +54 -0
  224. tensorbored/plugins/projector/projector_plugin.py +795 -0
  225. tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
  226. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
  227. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
  228. tensorbored/plugins/scalar/__init__.py +0 -0
  229. tensorbored/plugins/scalar/metadata.py +60 -0
  230. tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
  231. tensorbored/plugins/scalar/scalars_plugin.py +181 -0
  232. tensorbored/plugins/scalar/summary.py +109 -0
  233. tensorbored/plugins/scalar/summary_v2.py +124 -0
  234. tensorbored/plugins/text/__init__.py +0 -0
  235. tensorbored/plugins/text/metadata.py +62 -0
  236. tensorbored/plugins/text/plugin_data_pb2.py +34 -0
  237. tensorbored/plugins/text/summary.py +114 -0
  238. tensorbored/plugins/text/summary_v2.py +124 -0
  239. tensorbored/plugins/text/text_plugin.py +288 -0
  240. tensorbored/plugins/wit_redirect/__init__.py +0 -0
  241. tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
  242. tensorbored/program.py +910 -0
  243. tensorbored/summary/__init__.py +35 -0
  244. tensorbored/summary/_output.py +124 -0
  245. tensorbored/summary/_tf/__init__.py +14 -0
  246. tensorbored/summary/_tf/summary/__init__.py +178 -0
  247. tensorbored/summary/_writer.py +105 -0
  248. tensorbored/summary/v1.py +51 -0
  249. tensorbored/summary/v2.py +25 -0
  250. tensorbored/summary/writer/__init__.py +13 -0
  251. tensorbored/summary/writer/event_file_writer.py +291 -0
  252. tensorbored/summary/writer/record_writer.py +50 -0
  253. tensorbored/util/__init__.py +0 -0
  254. tensorbored/util/encoder.py +116 -0
  255. tensorbored/util/grpc_util.py +311 -0
  256. tensorbored/util/img_mime_type_detector.py +40 -0
  257. tensorbored/util/io_util.py +20 -0
  258. tensorbored/util/lazy_tensor_creator.py +110 -0
  259. tensorbored/util/op_evaluator.py +104 -0
  260. tensorbored/util/platform_util.py +20 -0
  261. tensorbored/util/tb_logging.py +24 -0
  262. tensorbored/util/tensor_util.py +617 -0
  263. tensorbored/util/timing.py +122 -0
  264. tensorbored/version.py +21 -0
  265. tensorbored/webfiles.zip +0 -0
  266. tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
  267. tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
  268. tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
  269. tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
  270. tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
  271. tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
@@ -0,0 +1,574 @@
1
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Precision--recall curves and TensorFlow operations to create them.
16
+
17
+ NOTE: This module is in beta, and its API is subject to change, but the
18
+ data that it stores to disk will be supported forever.
19
+ """
20
+
21
+ import numpy as np
22
+
23
+ from tensorbored.plugins.pr_curve import metadata
24
+
25
+ # A value that we use as the minimum value during division of counts to prevent
26
+ # division by 0. 1.0 does not work: Certain weights could cause counts below 1.
27
+ _MINIMUM_COUNT = 1e-7
28
+
29
+ # The default number of thresholds.
30
+ _DEFAULT_NUM_THRESHOLDS = 201
31
+
32
+
33
def op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create a PR curve summary op for a single binary classifier.

    Computes true/false positive/negative values for the given `predictions`
    against the ground truth `labels`, against a list of evenly distributed
    threshold values in `[0, 1]` of length `num_thresholds`.

    Each number in `predictions`, a float in `[0, 1]`, is compared with its
    corresponding boolean label in `labels`, and counts as a single tp/fp/tn/fn
    value at each threshold. This is then multiplied with `weights` which can be
    used to reweight certain values, or more commonly used for masking values.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values. A Tensor of `bool` values with arbitrary
        shape.
      predictions: A floating point `Tensor` (usually float32) whose values are
        in the range `[0, 1]`; values outside that range are clamped into it.
        Dimensions must match those of `labels`.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be `>= 2`. This value should be a
        constant integer value, not a Tensor that stores an integer. Defaults
        to 201.
      weights: Optional `Tensor` of the same dtype as `predictions`. Individual
        counts are multiplied by this value. This tensor must be either the
        same shape as or broadcastable to the `labels` tensor.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new
        summary op is added to these collections. Defaults to
        `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. The float32 tensor
      produced by the summary operation is of dimension (6, num_thresholds). The
      first dimension (of length 6) is of the order: true positives,
      false positives, true negatives, false negatives, precision, recall.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    dtype = predictions.dtype

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tf.assert_type(labels, tf.bool)
        # We cast to float to ensure we have 0.0 or 1.0.
        f_labels = tf.cast(labels, dtype)
        # Ensure predictions are all in range [0.0, 1.0].
        predictions = tf.minimum(1.0, tf.maximum(0.0, predictions))
        # Get weighted true/false labels.
        true_labels = f_labels * weights
        false_labels = (1.0 - f_labels) * weights

        # Before we begin, flatten predictions.
        predictions = tf.reshape(predictions, [-1])

        # Shape the labels so they are broadcast-able for later multiplication.
        true_labels = tf.reshape(true_labels, [-1, 1])
        false_labels = tf.reshape(false_labels, [-1, 1])

        # To compute TP/FP/TN/FN, we are measuring a binary classifier
        #   C(t) = (predictions >= t)
        # at each threshold 't'. So we have
        #   TP(t) = sum( C(t) * true_labels )
        #   FP(t) = sum( C(t) * false_labels )
        #
        # But, computing C(t) requires computation for each t. To make it fast,
        # observe that C(t) is a cumulative integral, and so if we have
        #   thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
        # where n = num_thresholds, and if we can compute the bucket function
        #   B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
        # then we get
        #   C(t_i) = sum( B(j), j >= i )
        # which is the reversed cumulative sum in tf.cumsum().
        #
        # We can compute B(i) efficiently by taking advantage of the fact that
        # our thresholds are evenly distributed, in that
        #   width = 1.0 / (num_thresholds - 1)
        #   thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
        # Given a prediction value p, we can map it to its bucket by
        #   bucket_index(p) = floor( p * (num_thresholds - 1) )
        # so a one-hot encoding of the bucket indices, weighted by the labels
        # and summed over the batch axis, fills every bucket in one pass.

        # Compute the bucket indices for each prediction value.
        bucket_indices = tf.cast(
            tf.floor(predictions * (num_thresholds - 1)), tf.int32
        )

        # Encode each prediction's bucket once and reuse it for both the TP
        # and FP histograms. Building it in `dtype` keeps the products below
        # type-consistent when predictions are not float32 (tf.one_hot would
        # otherwise default to float32).
        bucket_one_hot = tf.one_hot(
            bucket_indices, depth=num_thresholds, dtype=dtype
        )

        # Bucket predictions.
        tp_buckets = tf.reduce_sum(
            input_tensor=bucket_one_hot * true_labels,
            axis=0,
        )
        fp_buckets = tf.reduce_sum(
            input_tensor=bucket_one_hot * false_labels,
            axis=0,
        )

        # Set up the cumulative sums to compute the actual metrics.
        tp = tf.cumsum(tp_buckets, reverse=True, name="tp")
        fp = tf.cumsum(fp_buckets, reverse=True, name="fp")
        # fn = sum(true_labels) - tp
        #    = sum(tp_buckets) - tp
        #    = tp[0] - tp
        # Similarly,
        # tn = fp[0] - fp
        tn = fp[0] - fp
        fn = tp[0] - tp

        # _MINIMUM_COUNT guards against division by zero at thresholds where
        # no predictions were classified positive (or no positives exist).
        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        return _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            collections,
        )
175
+
176
+
177
def pb(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf.

    This is the numpy counterpart of `op`: the same bucketing and reverse
    cumulative sums are computed eagerly, and the resulting series are handed
    to `raw_data_pb`.

    Arguments:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      labels: The ground truth values. A bool numpy array (or array-like).
      predictions: A float32 numpy array (or array-like) whose values are in
        the range `[0, 1]`. Values outside that range are clamped into it,
        matching `op`. Dimensions must match those of `labels`.
      num_thresholds: Optional number of thresholds, evenly distributed in
        `[0, 1]`, to compute PR metrics for. When provided, should be an int of
        value at least 2. Defaults to 201.
      weights: Optional float or float32 numpy array. Individual counts are
        multiplied by this value. This array must be either the same shape as
        or broadcastable to the `labels` numpy array.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
        Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
        Markdown is supported. Defaults to empty.

    Returns:
      The summary protobuf produced by `raw_data_pb`.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf  # noqa: F401

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    if weights is None:
        weights = 1.0

    # Accept array-likes, and clamp predictions into [0, 1] exactly as `op`
    # does. Without the clamp, an out-of-range prediction maps to a bucket
    # index outside `histogram_range` and np.histogram silently drops it.
    labels = np.asarray(labels)
    predictions = np.clip(np.asarray(predictions), 0.0, 1.0)

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(float)
    # `num_thresholds` equal-width bins over [0, num_thresholds - 1] place
    # each integer bucket index i into bin i (the last bin is right-closed).
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)

    return raw_data_pb(
        name,
        true_positive_counts=tp,
        false_positive_counts=fp,
        true_negative_counts=tn,
        false_negative_counts=fn,
        precision=precision,
        recall=recall,
        num_thresholds=num_thresholds,
        display_name=display_name,
        description=description,
    )
251
+
252
+
253
def streaming_op(
    name,
    labels,
    predictions,
    num_thresholds=None,
    weights=None,
    metrics_collections=None,
    updates_collections=None,
    display_name=None,
    description=None,
):
    """Computes a precision-recall curve summary across batches of data.

    This function is similar to op() above, but can be used to compute the PR
    curve across multiple batches of labels and predictions, in the same style
    as the metrics found in tf.metrics.

    This function creates multiple local variables for storing true positives,
    true negatives, etc. accumulated over each batch of data, and uses these
    local variables for computing the final PR curve summary. These variables
    can be updated with the returned update_op.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      labels: The ground truth values, a `Tensor` whose dimensions must match
        `predictions`. Will be cast to `bool`.
      predictions: A floating point `Tensor` of arbitrary shape and whose values
        are in the range `[0, 1]`.
      num_thresholds: The number of evenly spaced thresholds to generate for
        computing the PR curve. Should be `>= 2`. Defaults to 201.
      weights: Optional `Tensor` whose rank is either 0, or the same rank as
        `labels`, and must be broadcastable to `labels` (i.e., all dimensions
        must be either `1`, or the same as the corresponding `labels`
        dimension).
      metrics_collections: An optional list of collections that the returned
        `pr_curve` summary should be added to.
      updates_collections: An optional list of collections that `update_op`
        should be added to.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      pr_curve: A string `Tensor` containing a single value: the
        serialized PR curve Tensor summary. The summary contains a
        float32 `Tensor` of dimension (6, num_thresholds). The first
        dimension (of length 6) is of the order: true positives, false
        positives, true negatives, false negatives, precision, recall.
      update_op: An operation that updates the summary with the latest data.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if num_thresholds is None:
        num_thresholds = _DEFAULT_NUM_THRESHOLDS

    # Evenly spaced thresholds from 0.0 to 1.0 inclusive.
    thresholds = [i / float(num_thresholds - 1) for i in range(num_thresholds)]

    with tf.name_scope(name, values=[labels, predictions, weights]):
        tp, update_tp = tf.metrics.true_positives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        fp, update_fp = tf.metrics.false_positives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        tn, update_tn = tf.metrics.true_negatives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )
        fn, update_fn = tf.metrics.false_negatives_at_thresholds(
            labels=labels,
            predictions=predictions,
            thresholds=thresholds,
            weights=weights,
        )

        # _MINIMUM_COUNT guards against division by zero at thresholds where
        # the accumulated denominators are still zero.
        precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
        recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

        pr_curve = _create_tensor_summary(
            name,
            tp,
            fp,
            tn,
            fn,
            precision,
            recall,
            num_thresholds,
            display_name,
            description,
            metrics_collections,
        )

        # A single op that advances all four accumulators together.
        update_op = tf.group(update_tp, update_fp, update_tn, update_fn)
        if updates_collections:
            for collection in updates_collections:
                tf.add_to_collection(collection, update_op)

        return pr_curve, update_op
362
+
363
+
364
def raw_data_op(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Create an op that records precomputed data for visualizing PR curves.

    In contrast to `op`, nothing is computed here: the caller supplies the
    four count series plus precision and recall, and is responsible for
    keeping them consistent (the counts should yield the given precision and
    recall). Use this when precision and recall are derived some other way
    but should still be displayed by the PR curves plugin.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      true_positive_counts: A rank-1 tensor of true positive counts. Must
        contain `num_thresholds` elements and be castable to float32. Values
        correspond to thresholds that increase from left to right (0 to 1).
      false_positive_counts: A rank-1 tensor of false positive counts, with
        the same length/castability/ordering requirements.
      true_negative_counts: A rank-1 tensor of true negative counts, with
        the same length/castability/ordering requirements.
      false_negative_counts: A rank-1 tensor of false negative counts, with
        the same length/castability/ordering requirements.
      precision: A rank-1 tensor of precision values, with the same
        length/castability/ordering requirements.
      recall: A rank-1 tensor of recall values, with the same
        length/castability/ordering requirements.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be `>= 2`. This value should be a
        constant integer value, not a Tensor that stores an integer.
      display_name: Optional name for this summary in TensorBoard, as a
        constant `str`. Defaults to `name`.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.
      collections: Optional list of graph collections keys. The new
        summary op is added to these collections. Defaults to
        `[GraphKeys.SUMMARIES]`.

    Returns:
      A summary operation for use in a TensorFlow graph. See docs for the `op`
      method for details on the float32 tensor produced by this summary.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    # The six series, in the positional order `_create_tensor_summary` takes.
    series = (
        true_positive_counts,
        false_positive_counts,
        true_negative_counts,
        false_negative_counts,
        precision,
        recall,
    )
    with tf.name_scope(name, values=list(series)):
        return _create_tensor_summary(
            name,
            *series,
            num_thresholds,
            display_name,
            description,
            collections,
        )
449
+
450
+
451
def raw_data_pb(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
):
    """Create a PR curves summary protobuf from raw data values.

    Unlike `raw_data_op`, this builds the summary proto directly from numpy
    arrays instead of adding ops to a TensorFlow graph. No consistency check
    is performed between the counts and the supplied precision/recall.

    Args:
      name: A tag attached to the summary. Used by TensorBoard for organization.
      true_positive_counts: A rank-1 numpy array of true positive counts. Must
        contain `num_thresholds` elements and be castable to float32.
      false_positive_counts: A rank-1 numpy array of false positive counts. Must
        contain `num_thresholds` elements and be castable to float32.
      true_negative_counts: A rank-1 numpy array of true negative counts. Must
        contain `num_thresholds` elements and be castable to float32.
      false_negative_counts: A rank-1 numpy array of false negative counts. Must
        contain `num_thresholds` elements and be castable to float32.
      precision: A rank-1 numpy array of precision values. Must contain
        `num_thresholds` elements and be castable to float32.
      recall: A rank-1 numpy array of recall values. Must contain `num_thresholds`
        elements and be castable to float32.
      num_thresholds: Number of thresholds, evenly distributed in `[0, 1]`, to
        compute PR metrics for. Should be an int `>= 2`.
      display_name: Optional name for this summary in TensorBoard, as a `str`.
        Defaults to `name`.
      description: Optional long-form description for this summary, as a `str`.
        Markdown is supported. Defaults to empty.

    Returns:
      A `tf.Summary` protobuf holding a single value tagged `<name>/pr_curves`.
      Its tensor is a float32 array of shape `[6, num_thresholds]` whose rows
      are, in order: true positive counts, false positive counts, true
      negative counts, false negative counts, precision, and recall.

    Raises:
      ValueError: (from `np.stack`) if the six inputs do not share one shape.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    if display_name is None:
        display_name = name
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name,
        description=description or "",
        num_thresholds=num_thresholds,
    )
    # Round-trip through bytes to obtain TF's own SummaryMetadata message
    # (presumably `metadata` builds tensorbored's proto class, which TF's
    # `Summary.value.add` would not accept directly).
    tf_summary_metadata = tf.SummaryMetadata.FromString(
        summary_metadata.SerializeToString()
    )

    # Row order must match `_create_tensor_summary`: TP, FP, TN, FN,
    # precision, recall.
    data = np.stack(
        (
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall,
        )
    ).astype(np.float32)
    tensor = tf.make_tensor_proto(data, dtype=tf.float32)

    summary = tf.Summary()
    summary.value.add(
        tag="%s/pr_curves" % name, metadata=tf_summary_metadata, tensor=tensor
    )
    return summary
519
+
520
+
521
def _create_tensor_summary(
    name,
    true_positive_counts,
    false_positive_counts,
    true_negative_counts,
    false_negative_counts,
    precision,
    recall,
    num_thresholds=None,
    display_name=None,
    description=None,
    collections=None,
):
    """Build the `pr_curves` tensor summary shared by the PR curve ops.

    `op` calls this helper rather than `raw_data_op` so that the name scope
    opened by `raw_data_op` is not nested inside the one opened by `op`.

    Arguments have the same meaning as in `raw_data_op`.

    Returns:
      A tensor summary op carrying the PR curve data.
    """
    # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
    import tensorflow.compat.v1 as tf

    # num_thresholds is the same for every PR curve summary under a tag, so
    # it is recorded once in the summary metadata instead of in the tensor.
    summary_metadata = metadata.create_summary_metadata(
        display_name=name if display_name is None else display_name,
        description=description or "",
        num_thresholds=num_thresholds,
    )

    # Stacked row order: true positives, false positives, true negatives,
    # false negatives, precision, recall. Every row is cast to float32.
    rows = (
        true_positive_counts,
        false_positive_counts,
        true_negative_counts,
        false_negative_counts,
        precision,
        recall,
    )
    combined_data = tf.stack([tf.cast(row, tf.float32) for row in rows])

    return tf.summary.tensor_summary(
        name="pr_curves",
        tensor=combined_data,
        collections=collections,
        summary_metadata=summary_metadata,
    )
File without changes
@@ -0,0 +1,49 @@
1
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Plugin that only displays a message with installation instructions."""
16
+
17
+ from tensorbored.plugins import base_plugin
18
+
19
+
20
class ProfileRedirectPluginLoader(base_plugin.TBLoader):
    """Loader that provides the profile redirect notice only when needed.

    The notice plugin is produced only if `tensorboard_plugin_profile`
    cannot be imported.
    """

    def load(self, context):
        """Return the redirect plugin, or None if the profile plugin exists.

        Args:
          context: Plugin context object, forwarded unchanged to
            `_ProfileRedirectPlugin`.

        Returns:
          A `_ProfileRedirectPlugin` when `tensorboard_plugin_profile` fails
          to import with `ImportError`; otherwise None.
        """
        try:
            import tensorboard_plugin_profile  # noqa: F401
        except ImportError:
            # Dynamic profile plugin is missing: show the redirect notice.
            return _ProfileRedirectPlugin(context)
        # The dynamic plugin imported fine, so suppress the redirect notice.
        return None
32
+
33
+
34
class _ProfileRedirectPlugin(base_plugin.TBPlugin):
    """Placeholder plugin that points users to the dynamic profile plugin.

    It serves no HTTP routes and always reports itself as inactive. Its only
    payload is frontend metadata naming the dashboard element and tab title.
    """

    plugin_name = "profile_redirect"

    def get_plugin_apps(self):
        """Return an empty route table; this plugin has no backend handlers."""
        return dict()

    def is_active(self):
        """Always report inactive."""
        return False

    def frontend_metadata(self):
        """Describe the frontend element and tab used for the notice."""
        metadata = base_plugin.FrontendMetadata(
            element_name="tf-profile-redirect-dashboard",
            tab_name="Profile",
        )
        return metadata
@@ -0,0 +1,67 @@
1
+ # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Public API for the Embedding Projector.
16
+
17
+ @@ProjectorPluginAsset
18
+ @@ProjectorConfig
19
+ @@EmbeddingInfo
20
+ @@EmbeddingMetadata
21
+ @@SpriteMetadata
22
+ """
23
+
24
+ import os
25
+
26
+ from google.protobuf import text_format as _text_format
27
+ from tensorbored.compat import tf
28
+ from tensorbored.plugins.projector import metadata as _metadata
29
+ from tensorbored.plugins.projector.projector_config_pb2 import ( # noqa: F401
30
+ EmbeddingInfo,
31
+ )
32
+ from tensorbored.plugins.projector.projector_config_pb2 import ( # noqa: F401
33
+ SpriteMetadata,
34
+ )
35
+ from tensorbored.plugins.projector.projector_config_pb2 import ( # noqa: F401
36
+ ProjectorConfig,
37
+ )
38
+
39
+
40
def visualize_embeddings(logdir, config):
    """Write the embedding projector's config file into a log directory.

    Args:
      logdir: Directory, as a `str`, in which to write the config file. For
        compatibility, this may instead be any object with a `get_logdir()`
        method (such as a `tf.compat.v1.summary.FileWriter`); the value
        returned by `get_logdir()` is then used as the directory.
      config: A `ProjectorConfig` proto holding the projector configuration,
        such as paths to checkpoint and metadata files for the embeddings.
        It is written verbatim in protobuf text format; this function does
        not fill in missing fields such as `model_checkpoint_path`.

    Raises:
      ValueError: If the resolved `logdir` is None.
    """
    # Unwrap a FileWriter-like object into the directory it writes to.
    if hasattr(logdir, "get_logdir"):
        logdir = logdir.get_logdir()

    if logdir is None:
        raise ValueError("Expected logdir to be a path, but got None")

    # Serialize first, then store under the projector's well-known filename.
    serialized_config = _text_format.MessageToString(config)
    config_path = os.path.join(logdir, _metadata.PROJECTOR_FILENAME)
    with tf.io.gfile.GFile(config_path, "w") as config_file:
        config_file.write(serialized_config)