tensorbored 2.21.0rc1769983804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (271) hide show
  1. tensorbored/__init__.py +112 -0
  2. tensorbored/_vendor/__init__.py +0 -0
  3. tensorbored/_vendor/bleach/__init__.py +125 -0
  4. tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
  5. tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
  6. tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
  7. tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
  8. tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
  9. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
  10. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
  11. tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
  12. tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
  13. tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
  14. tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
  15. tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
  16. tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
  17. tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
  18. tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
  19. tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
  20. tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
  21. tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
  22. tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
  23. tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
  24. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
  25. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
  26. tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
  27. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
  28. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
  29. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
  30. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
  31. tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
  32. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
  33. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
  34. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
  35. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
  36. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
  37. tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
  38. tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
  39. tensorbored/_vendor/bleach/callbacks.py +32 -0
  40. tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
  41. tensorbored/_vendor/bleach/linkifier.py +633 -0
  42. tensorbored/_vendor/bleach/parse_shim.py +1 -0
  43. tensorbored/_vendor/bleach/sanitizer.py +638 -0
  44. tensorbored/_vendor/bleach/six_shim.py +19 -0
  45. tensorbored/_vendor/webencodings/__init__.py +342 -0
  46. tensorbored/_vendor/webencodings/labels.py +231 -0
  47. tensorbored/_vendor/webencodings/mklabels.py +59 -0
  48. tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
  49. tensorbored/assets.py +36 -0
  50. tensorbored/auth.py +102 -0
  51. tensorbored/backend/__init__.py +0 -0
  52. tensorbored/backend/application.py +604 -0
  53. tensorbored/backend/auth_context_middleware.py +38 -0
  54. tensorbored/backend/client_feature_flags.py +113 -0
  55. tensorbored/backend/empty_path_redirect.py +46 -0
  56. tensorbored/backend/event_processing/__init__.py +0 -0
  57. tensorbored/backend/event_processing/data_ingester.py +276 -0
  58. tensorbored/backend/event_processing/data_provider.py +535 -0
  59. tensorbored/backend/event_processing/directory_loader.py +142 -0
  60. tensorbored/backend/event_processing/directory_watcher.py +272 -0
  61. tensorbored/backend/event_processing/event_accumulator.py +950 -0
  62. tensorbored/backend/event_processing/event_file_inspector.py +463 -0
  63. tensorbored/backend/event_processing/event_file_loader.py +292 -0
  64. tensorbored/backend/event_processing/event_multiplexer.py +521 -0
  65. tensorbored/backend/event_processing/event_util.py +68 -0
  66. tensorbored/backend/event_processing/io_wrapper.py +223 -0
  67. tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
  68. tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
  69. tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
  70. tensorbored/backend/event_processing/reservoir.py +266 -0
  71. tensorbored/backend/event_processing/tag_types.py +29 -0
  72. tensorbored/backend/experiment_id.py +71 -0
  73. tensorbored/backend/experimental_plugin.py +51 -0
  74. tensorbored/backend/http_util.py +263 -0
  75. tensorbored/backend/json_util.py +70 -0
  76. tensorbored/backend/path_prefix.py +67 -0
  77. tensorbored/backend/process_graph.py +74 -0
  78. tensorbored/backend/security_validator.py +202 -0
  79. tensorbored/compat/__init__.py +69 -0
  80. tensorbored/compat/proto/__init__.py +0 -0
  81. tensorbored/compat/proto/allocation_description_pb2.py +35 -0
  82. tensorbored/compat/proto/api_def_pb2.py +82 -0
  83. tensorbored/compat/proto/attr_value_pb2.py +80 -0
  84. tensorbored/compat/proto/cluster_pb2.py +58 -0
  85. tensorbored/compat/proto/config_pb2.py +271 -0
  86. tensorbored/compat/proto/coordination_config_pb2.py +45 -0
  87. tensorbored/compat/proto/cost_graph_pb2.py +87 -0
  88. tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
  89. tensorbored/compat/proto/debug_pb2.py +65 -0
  90. tensorbored/compat/proto/event_pb2.py +149 -0
  91. tensorbored/compat/proto/full_type_pb2.py +74 -0
  92. tensorbored/compat/proto/function_pb2.py +157 -0
  93. tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
  94. tensorbored/compat/proto/graph_pb2.py +41 -0
  95. tensorbored/compat/proto/histogram_pb2.py +39 -0
  96. tensorbored/compat/proto/meta_graph_pb2.py +254 -0
  97. tensorbored/compat/proto/node_def_pb2.py +61 -0
  98. tensorbored/compat/proto/op_def_pb2.py +81 -0
  99. tensorbored/compat/proto/resource_handle_pb2.py +48 -0
  100. tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
  101. tensorbored/compat/proto/rpc_options_pb2.py +35 -0
  102. tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
  103. tensorbored/compat/proto/saver_pb2.py +38 -0
  104. tensorbored/compat/proto/step_stats_pb2.py +116 -0
  105. tensorbored/compat/proto/struct_pb2.py +144 -0
  106. tensorbored/compat/proto/summary_pb2.py +111 -0
  107. tensorbored/compat/proto/tensor_description_pb2.py +38 -0
  108. tensorbored/compat/proto/tensor_pb2.py +68 -0
  109. tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
  110. tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
  111. tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
  112. tensorbored/compat/proto/types_pb2.py +105 -0
  113. tensorbored/compat/proto/variable_pb2.py +62 -0
  114. tensorbored/compat/proto/verifier_config_pb2.py +38 -0
  115. tensorbored/compat/proto/versions_pb2.py +35 -0
  116. tensorbored/compat/tensorflow_stub/__init__.py +38 -0
  117. tensorbored/compat/tensorflow_stub/app.py +124 -0
  118. tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
  119. tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
  120. tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
  121. tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
  122. tensorbored/compat/tensorflow_stub/errors.py +507 -0
  123. tensorbored/compat/tensorflow_stub/flags.py +124 -0
  124. tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
  125. tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
  126. tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
  127. tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
  128. tensorbored/context.py +129 -0
  129. tensorbored/data/__init__.py +0 -0
  130. tensorbored/data/grpc_provider.py +365 -0
  131. tensorbored/data/ingester.py +46 -0
  132. tensorbored/data/proto/__init__.py +0 -0
  133. tensorbored/data/proto/data_provider_pb2.py +517 -0
  134. tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
  135. tensorbored/data/provider.py +1365 -0
  136. tensorbored/data/server_ingester.py +301 -0
  137. tensorbored/data_compat.py +159 -0
  138. tensorbored/dataclass_compat.py +224 -0
  139. tensorbored/default.py +124 -0
  140. tensorbored/errors.py +130 -0
  141. tensorbored/lazy.py +99 -0
  142. tensorbored/main.py +48 -0
  143. tensorbored/main_lib.py +62 -0
  144. tensorbored/manager.py +487 -0
  145. tensorbored/notebook.py +441 -0
  146. tensorbored/plugin_util.py +266 -0
  147. tensorbored/plugins/__init__.py +0 -0
  148. tensorbored/plugins/audio/__init__.py +0 -0
  149. tensorbored/plugins/audio/audio_plugin.py +229 -0
  150. tensorbored/plugins/audio/metadata.py +69 -0
  151. tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
  152. tensorbored/plugins/audio/summary.py +230 -0
  153. tensorbored/plugins/audio/summary_v2.py +124 -0
  154. tensorbored/plugins/base_plugin.py +367 -0
  155. tensorbored/plugins/core/__init__.py +0 -0
  156. tensorbored/plugins/core/core_plugin.py +981 -0
  157. tensorbored/plugins/custom_scalar/__init__.py +0 -0
  158. tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
  159. tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
  160. tensorbored/plugins/custom_scalar/metadata.py +35 -0
  161. tensorbored/plugins/custom_scalar/summary.py +79 -0
  162. tensorbored/plugins/debugger_v2/__init__.py +0 -0
  163. tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
  164. tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
  165. tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
  166. tensorbored/plugins/distribution/__init__.py +0 -0
  167. tensorbored/plugins/distribution/compressor.py +158 -0
  168. tensorbored/plugins/distribution/distributions_plugin.py +116 -0
  169. tensorbored/plugins/distribution/metadata.py +19 -0
  170. tensorbored/plugins/graph/__init__.py +0 -0
  171. tensorbored/plugins/graph/graph_util.py +129 -0
  172. tensorbored/plugins/graph/graphs_plugin.py +336 -0
  173. tensorbored/plugins/graph/keras_util.py +328 -0
  174. tensorbored/plugins/graph/metadata.py +42 -0
  175. tensorbored/plugins/histogram/__init__.py +0 -0
  176. tensorbored/plugins/histogram/histograms_plugin.py +144 -0
  177. tensorbored/plugins/histogram/metadata.py +63 -0
  178. tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
  179. tensorbored/plugins/histogram/summary.py +234 -0
  180. tensorbored/plugins/histogram/summary_v2.py +292 -0
  181. tensorbored/plugins/hparams/__init__.py +14 -0
  182. tensorbored/plugins/hparams/_keras.py +93 -0
  183. tensorbored/plugins/hparams/api.py +130 -0
  184. tensorbored/plugins/hparams/api_pb2.py +208 -0
  185. tensorbored/plugins/hparams/backend_context.py +606 -0
  186. tensorbored/plugins/hparams/download_data.py +158 -0
  187. tensorbored/plugins/hparams/error.py +26 -0
  188. tensorbored/plugins/hparams/get_experiment.py +71 -0
  189. tensorbored/plugins/hparams/hparams_plugin.py +206 -0
  190. tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
  191. tensorbored/plugins/hparams/json_format_compat.py +38 -0
  192. tensorbored/plugins/hparams/list_metric_evals.py +57 -0
  193. tensorbored/plugins/hparams/list_session_groups.py +1040 -0
  194. tensorbored/plugins/hparams/metadata.py +125 -0
  195. tensorbored/plugins/hparams/metrics.py +41 -0
  196. tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
  197. tensorbored/plugins/hparams/summary.py +205 -0
  198. tensorbored/plugins/hparams/summary_v2.py +597 -0
  199. tensorbored/plugins/image/__init__.py +0 -0
  200. tensorbored/plugins/image/images_plugin.py +232 -0
  201. tensorbored/plugins/image/metadata.py +65 -0
  202. tensorbored/plugins/image/plugin_data_pb2.py +34 -0
  203. tensorbored/plugins/image/summary.py +159 -0
  204. tensorbored/plugins/image/summary_v2.py +130 -0
  205. tensorbored/plugins/mesh/__init__.py +14 -0
  206. tensorbored/plugins/mesh/mesh_plugin.py +292 -0
  207. tensorbored/plugins/mesh/metadata.py +152 -0
  208. tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
  209. tensorbored/plugins/mesh/summary.py +251 -0
  210. tensorbored/plugins/mesh/summary_v2.py +214 -0
  211. tensorbored/plugins/metrics/__init__.py +0 -0
  212. tensorbored/plugins/metrics/metadata.py +17 -0
  213. tensorbored/plugins/metrics/metrics_plugin.py +623 -0
  214. tensorbored/plugins/pr_curve/__init__.py +0 -0
  215. tensorbored/plugins/pr_curve/metadata.py +75 -0
  216. tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
  217. tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
  218. tensorbored/plugins/pr_curve/summary.py +574 -0
  219. tensorbored/plugins/profile_redirect/__init__.py +0 -0
  220. tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
  221. tensorbored/plugins/projector/__init__.py +67 -0
  222. tensorbored/plugins/projector/metadata.py +26 -0
  223. tensorbored/plugins/projector/projector_config_pb2.py +54 -0
  224. tensorbored/plugins/projector/projector_plugin.py +795 -0
  225. tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
  226. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
  227. tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
  228. tensorbored/plugins/scalar/__init__.py +0 -0
  229. tensorbored/plugins/scalar/metadata.py +60 -0
  230. tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
  231. tensorbored/plugins/scalar/scalars_plugin.py +181 -0
  232. tensorbored/plugins/scalar/summary.py +109 -0
  233. tensorbored/plugins/scalar/summary_v2.py +124 -0
  234. tensorbored/plugins/text/__init__.py +0 -0
  235. tensorbored/plugins/text/metadata.py +62 -0
  236. tensorbored/plugins/text/plugin_data_pb2.py +34 -0
  237. tensorbored/plugins/text/summary.py +114 -0
  238. tensorbored/plugins/text/summary_v2.py +124 -0
  239. tensorbored/plugins/text/text_plugin.py +288 -0
  240. tensorbored/plugins/wit_redirect/__init__.py +0 -0
  241. tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
  242. tensorbored/program.py +910 -0
  243. tensorbored/summary/__init__.py +35 -0
  244. tensorbored/summary/_output.py +124 -0
  245. tensorbored/summary/_tf/__init__.py +14 -0
  246. tensorbored/summary/_tf/summary/__init__.py +178 -0
  247. tensorbored/summary/_writer.py +105 -0
  248. tensorbored/summary/v1.py +51 -0
  249. tensorbored/summary/v2.py +25 -0
  250. tensorbored/summary/writer/__init__.py +13 -0
  251. tensorbored/summary/writer/event_file_writer.py +291 -0
  252. tensorbored/summary/writer/record_writer.py +50 -0
  253. tensorbored/util/__init__.py +0 -0
  254. tensorbored/util/encoder.py +116 -0
  255. tensorbored/util/grpc_util.py +311 -0
  256. tensorbored/util/img_mime_type_detector.py +40 -0
  257. tensorbored/util/io_util.py +20 -0
  258. tensorbored/util/lazy_tensor_creator.py +110 -0
  259. tensorbored/util/op_evaluator.py +104 -0
  260. tensorbored/util/platform_util.py +20 -0
  261. tensorbored/util/tb_logging.py +24 -0
  262. tensorbored/util/tensor_util.py +617 -0
  263. tensorbored/util/timing.py +122 -0
  264. tensorbored/version.py +21 -0
  265. tensorbored/webfiles.zip +0 -0
  266. tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
  267. tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
  268. tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
  269. tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
  270. tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
  271. tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1040 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Classes and functions for handling the ListSessionGroups API call."""
16
+
17
+ import collections
18
+ import dataclasses
19
+ import operator
20
+ import re
21
+ from typing import Optional
22
+
23
+ from google.protobuf import struct_pb2
24
+
25
+ from tensorbored.data import provider
26
+ from tensorbored.plugins.hparams import api_pb2
27
+ from tensorbored.plugins.hparams import backend_context as backend_context_lib
28
+ from tensorbored.plugins.hparams import error
29
+ from tensorbored.plugins.hparams import json_format_compat
30
+ from tensorbored.plugins.hparams import metadata
31
+ from tensorbored.plugins.hparams import metrics
32
+ from tensorbored.plugins.hparams import plugin_data_pb2
33
+
34
+
35
class Handler:
    """Handles a ListSessionGroups request.

    Session groups are built from one of two sources, tried in order by
    `run()`:

    1. hparams summary tag metadata (runs carrying
       metadata.SESSION_START_INFO_TAG), via `_session_groups_from_tags`;
    2. the backend context's DataProvider-based session groups, via
       `_session_groups_from_data_provider`.
    """

    def __init__(
        self, request_context, backend_context, experiment_id, request
    ):
        """Constructor.

        Args:
          request_context: A tensorbored.context.RequestContext.
          backend_context: A backend_context.Context instance.
          experiment_id: A string, as from `plugin_util.experiment_id`.
          request: A ListSessionGroupsRequest protobuf.
        """
        self._request_context = request_context
        self._backend_context = backend_context
        self._experiment_id = experiment_id
        self._request = request
        self._include_metrics = (
            # Metrics are included by default if include_metrics is not
            # specified in the request.
            not self._request.HasField("include_metrics")
            or self._request.include_metrics
        )

    def run(self):
        """Handles the request specified on construction.

        Session groups are first constructed from hparams tag metadata
        (runs carrying metadata.SESSION_START_INFO_TAG). If that yields no
        session groups, they are instead built from the session groups
        returned by `backend_context.session_groups_from_data_provider()`.

        Returns:
          A ListSessionGroupsResponse object. When neither source yields any
          session group, the response is empty with total_size=0.
        """
        session_groups_from_tags = self._session_groups_from_tags()
        if session_groups_from_tags:
            return self._create_response(session_groups_from_tags)

        session_groups_from_data_provider = (
            self._session_groups_from_data_provider()
        )
        if session_groups_from_data_provider:
            return self._create_response(session_groups_from_data_provider)

        return api_pb2.ListSessionGroupsResponse(
            session_groups=[], total_size=0
        )

    def _session_groups_from_tags(self):
        """Constructs a list of SessionGroups based on hparams tag metadata.

        Returns:
          A list of api_pb2.SessionGroup protobufs, filtered and sorted
          according to the request's col_params (and reduced to the
          requested hparams when the request specifies any). May be empty.
        """
        # Query for all hparams summary metadata one time to minimize calls
        # to the underlying DataProvider.
        hparams_run_to_tag_to_content = self._backend_context.hparams_metadata(
            self._request_context, self._experiment_id
        )
        # Construct the experiment one time since a context.experiment() call
        # may search through all the runs.
        experiment = self._backend_context.experiment_from_metadata(
            self._request_context,
            self._experiment_id,
            self._include_metrics,
            hparams_run_to_tag_to_content,
            # Don't pass any information from the DataProvider since we are
            # only examining session groups based on tag metadata.
            provider.ListHyperparametersResult(
                hyperparameters=[], session_groups=[]
            ),
        )
        extractors = _create_extractors(self._request.col_params)
        filters = _create_filters(self._request.col_params, extractors)

        session_groups = self._build_session_groups(
            hparams_run_to_tag_to_content, experiment.metric_infos
        )
        session_groups = self._filter(session_groups, filters)
        self._sort(session_groups, extractors)

        if _specifies_include(self._request.col_params):
            _reduce_to_hparams_to_include(
                session_groups, self._request.col_params
            )

        return session_groups

    def _session_groups_from_data_provider(self):
        """Constructs a list of SessionGroups based on DataProvider results.

        Hparam filters, the requested sort, and the hparams to include are
        passed to `backend_context.session_groups_from_data_provider()`.
        Only the non-hparam (metric) filters are re-applied here.

        Returns:
          A list of api_pb2.SessionGroup protobufs. May be empty.
        """
        filters = _build_data_provider_filters(self._request.col_params)
        sort = _build_data_provider_sort(self._request.col_params)
        hparams_to_include = (
            _get_hparams_to_include(self._request.col_params)
            if _specifies_include(self._request.col_params)
            else None
        )
        response = self._backend_context.session_groups_from_data_provider(
            self._request_context,
            self._experiment_id,
            filters,
            sort,
            hparams_to_include,
        )

        metric_infos = (
            self._backend_context.compute_metric_infos_from_data_provider_session_groups(
                self._request_context, self._experiment_id, response
            )
            if self._include_metrics
            else []
        )

        all_metric_evals = (
            self._backend_context.read_last_scalars(
                self._request_context,
                self._experiment_id,
                run_tag_filter=None,
            )
            if self._include_metrics
            else {}
        )

        session_groups = []
        for provider_group in response:
            sessions = []
            for session in provider_group.sessions:
                session_name = (
                    backend_context_lib.generate_data_provider_session_name(
                        session
                    )
                )
                sessions.append(
                    self._build_session(
                        metric_infos,
                        session_name,
                        # No start/end info is read on this path, so
                        # default-valued (empty) protos are used.
                        plugin_data_pb2.SessionStartInfo(),
                        plugin_data_pb2.SessionEndInfo(),
                        all_metric_evals,
                    )
                )

            name = backend_context_lib.generate_data_provider_session_name(
                provider_group.root
            )
            if not name:
                # Fall back to the experiment id when the root yields no name.
                name = self._experiment_id
            session_group = api_pb2.SessionGroup(
                name=name,
                sessions=sessions,
            )

            for provider_hparam in provider_group.hyperparameter_values:
                # Indexing the message-valued map inserts the entry even if
                # none of the branches below assigns a value to it.
                hparam = session_group.hparams[
                    provider_hparam.hyperparameter_name
                ]
                domain_type = provider_hparam.domain_type
                if (
                    domain_type
                    == provider.HyperparameterDomainType.DISCRETE_STRING
                ):
                    hparam.string_value = provider_hparam.value
                elif domain_type in (
                    provider.HyperparameterDomainType.DISCRETE_FLOAT,
                    provider.HyperparameterDomainType.INTERVAL,
                ):
                    hparam.number_value = provider_hparam.value
                elif (
                    domain_type
                    == provider.HyperparameterDomainType.DISCRETE_BOOL
                ):
                    hparam.bool_value = provider_hparam.value

            session_groups.append(session_group)

        # Compute the session group's aggregated metrics for each group.
        for group in session_groups:
            if group.sessions:
                self._aggregate_metrics(group)

        extractors = _create_extractors(self._request.col_params)
        filters = _create_filters(
            self._request.col_params,
            extractors,
            # We assume the DataProvider will apply hparam filters and we do
            # not attempt to reapply them.
            include_hparam_filters=False,
        )
        session_groups = self._filter(session_groups, filters)
        return session_groups

    def _build_session_groups(
        self, hparams_run_to_tag_to_content, metric_infos
    ):
        """Returns a list of SessionGroup protobufs from the summary data.

        Args:
          hparams_run_to_tag_to_content: A dict mapping each run name to a
            dict mapping hparams tag name to that tag's plugin content, as
            returned by `backend_context.hparams_metadata()`.
          metric_infos: An iterable of metric-info messages (each with a
            `name`) describing which metrics to read for each session.

        Returns:
          A list of api_pb2.SessionGroup protobufs, one per distinct group
          name. Each group's sessions are sorted by name and its aggregated
          metrics are set. Sessions whose status is not in the request's
          `allowed_statuses` are omitted.
        """
        # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup
        # name (str) to a SessionGroup protobuffer. We traverse the runs
        # associated with the plugin--each representing a single session. We
        # form a Session protobuffer from each run and add it to the relevant
        # SessionGroup object in the 'groups_by_name' dict. We create the
        # SessionGroup object if this is the first session of that group we
        # encounter.
        groups_by_name = {}
        # The TensorBoard runs with session start info are the "sessions",
        # which are not necessarily the runs that actually contain metrics
        # (those may be in subdirectories).
        session_names = [
            run
            for (run, tags) in hparams_run_to_tag_to_content.items()
            if metadata.SESSION_START_INFO_TAG in tags
        ]
        # Collect every (run, tag) that may hold a metric so the last scalar
        # values can be fetched with a single read.
        metric_runs = set()
        metric_tags = set()
        for session_name in session_names:
            for metric in metric_infos:
                run, tag = metrics.run_tag_from_session_and_metric(
                    session_name, metric.name
                )
                metric_runs.add(run)
                metric_tags.add(tag)
        all_metric_evals = (
            self._backend_context.read_last_scalars(
                self._request_context,
                self._experiment_id,
                run_tag_filter=provider.RunTagFilter(
                    runs=metric_runs, tags=metric_tags
                ),
            )
            if self._include_metrics
            else {}
        )
        # 'session_names' already holds exactly the runs with session start
        # info, in dict iteration order, so reuse it instead of rescanning.
        for session_name in session_names:
            tag_to_content = hparams_run_to_tag_to_content[session_name]
            start_info = metadata.parse_session_start_info_plugin_data(
                tag_to_content[metadata.SESSION_START_INFO_TAG]
            )
            end_info = None
            if metadata.SESSION_END_INFO_TAG in tag_to_content:
                end_info = metadata.parse_session_end_info_plugin_data(
                    tag_to_content[metadata.SESSION_END_INFO_TAG]
                )
            session = self._build_session(
                metric_infos,
                session_name,
                start_info,
                end_info,
                all_metric_evals,
            )
            if session.status in self._request.allowed_statuses:
                self._add_session(session, start_info, groups_by_name)

        # Compute the session group's aggregated metrics for each group.
        groups = list(groups_by_name.values())
        for group in groups:
            # We sort the sessions in a group so that the order is
            # deterministic.
            group.sessions.sort(key=operator.attrgetter("name"))
            self._aggregate_metrics(group)
        return groups

    def _add_session(self, session, start_info, groups_by_name):
        """Adds a Session protobuffer to the 'groups_by_name' dictionary.

        Called by _build_session_groups for each session it keeps. Appends
        the session to its group in 'groups_by_name', creating the group
        (with hparams copied from 'start_info') the first time that group
        is encountered.

        Args:
          session: api_pb2.Session. The session to add.
          start_info: The SessionStartInfo protobuffer associated with the
            session.
          groups_by_name: A str to SessionGroup protobuffer dict representing
            the session groups and sessions found so far. Mutated in place.
        """
        # If the group_name is empty, this session's group contains only
        # this session. Use the session name for the group name since session
        # names are unique.
        group_name = start_info.group_name or session.name
        if group_name in groups_by_name:
            groups_by_name[group_name].sessions.extend([session])
        else:
            # Create the group and add the session as the first one.
            group = api_pb2.SessionGroup(
                name=group_name,
                sessions=[session],
                monitor_url=start_info.monitor_url,
            )
            # Copy hparams from the first session (all sessions should have
            # the same hyperparameter values) into the result. There doesn't
            # seem to be a way to initialize a protobuffer map in the
            # constructor.
            for key, value in start_info.hparams.items():
                if not json_format_compat.is_serializable_value(value):
                    # NaN number_value cannot be serialized by higher level
                    # layers that are using json_format.MessageToJson(). To
                    # work around the issue we do not copy them to the
                    # session group and effectively treat them as "unset".
                    continue

                group.hparams[key].CopyFrom(value)
            groups_by_name[group_name] = group

    def _build_session(
        self, metric_infos, name, start_info, end_info, all_metric_evals
    ):
        """Builds an api_pb2.Session.

        Args:
          metric_infos: Metric-info messages (each with a `name`) to look up.
          name: The session name.
          start_info: A SessionStartInfo protobuf; must not be None.
          end_info: A SessionEndInfo protobuf, or None. When None, the
            session's status and end_time_secs are left unset.
          all_metric_evals: A dict mapping run -> tag -> scalar datum.

        Returns:
          An api_pb2.Session protobuf.
        """
        assert start_info is not None
        result = api_pb2.Session(
            name=name,
            start_time_secs=start_info.start_time_secs,
            model_uri=start_info.model_uri,
            metric_values=self._build_session_metric_values(
                metric_infos, name, all_metric_evals
            ),
            monitor_url=start_info.monitor_url,
        )
        if end_info is not None:
            result.status = end_info.status
            result.end_time_secs = end_info.end_time_secs
        return result

    def _build_session_metric_values(
        self, metric_infos, session_name, all_metric_evals
    ):
        """Builds the session metric values.

        Args:
          metric_infos: Metric-info messages (each with a `name`).
          session_name: The session whose metrics are looked up.
          all_metric_evals: A dict mapping run -> tag -> scalar datum, where
            each datum has `wall_time`, `step` and `value` attributes.

        Returns:
          A list of api_pb2.MetricValue instances, one per metric found.
          Metrics without a recorded value are omitted.
        """
        result = []
        for metric_info in metric_infos:
            metric_name = metric_info.name
            run, tag = metrics.run_tag_from_session_and_metric(
                session_name, metric_name
            )
            datum = all_metric_evals.get(run, {}).get(tag)
            if datum is None:
                # It's ok if we don't find the metric in the session. We skip
                # it here. For filtering and sorting purposes its value is
                # None.
                continue
            result.append(
                api_pb2.MetricValue(
                    name=metric_name,
                    wall_time_secs=datum.wall_time,
                    training_step=datum.step,
                    value=datum.value,
                )
            )
        return result

    def _aggregate_metrics(self, session_group):
        """Sets the metrics of the group based on the request's aggregation.

        AGGREGATION_UNSET is treated like AGGREGATION_AVG.

        Raises:
          error.HParamsError: if the request's aggregation_type is unknown.
        """
        aggregation_type = self._request.aggregation_type
        if aggregation_type in (
            api_pb2.AGGREGATION_AVG,
            api_pb2.AGGREGATION_UNSET,
        ):
            _set_avg_session_metrics(session_group)
        elif aggregation_type == api_pb2.AGGREGATION_MEDIAN:
            _set_median_session_metrics(
                session_group, self._request.aggregation_metric
            )
        elif aggregation_type == api_pb2.AGGREGATION_MIN:
            _set_extremum_session_metrics(
                session_group, self._request.aggregation_metric, min
            )
        elif aggregation_type == api_pb2.AGGREGATION_MAX:
            _set_extremum_session_metrics(
                session_group, self._request.aggregation_metric, max
            )
        else:
            raise error.HParamsError(
                "Unknown aggregation_type in request: %s"
                % self._request.aggregation_type
            )

    def _filter(self, session_groups, filters):
        """Returns a new list of the session groups passing every filter."""
        return [
            sg for sg in session_groups if self._passes_all_filters(sg, filters)
        ]

    def _passes_all_filters(self, session_group, filters):
        """Returns True iff every callable in 'filters' accepts the group."""
        return all(filter_fn(session_group) for filter_fn in filters)

    def _sort(self, session_groups, extractors):
        """Sorts 'session_groups' in place according to _request.col_params.

        Args:
          session_groups: A list of SessionGroup protobufs; sorted in place.
          extractors: Callables parallel to _request.col_params, each
            extracting that column's value from a session group.

        Raises:
          error.HParamsError: if a col_param has an unknown order.
        """
        # Sort by session_group name so we have a deterministic order.
        session_groups.sort(key=operator.attrgetter("name"))
        # Sort by lexicographical order of the _request.col_params whose
        # order is not ORDER_UNSPECIFIED. The first such column is the
        # primary sorting key, the second is the secondary sorting key, etc.
        # To achieve that we iterate on these columns in reverse order (thus
        # the primary key is the key used in the last sort); this relies on
        # list.sort() being stable.
        for col_param, extractor in reversed(
            list(zip(self._request.col_params, extractors))
        ):
            if col_param.order == api_pb2.ORDER_UNSPECIFIED:
                continue
            if col_param.order == api_pb2.ORDER_ASC:
                session_groups.sort(
                    key=_create_key_func(
                        extractor,
                        none_is_largest=not col_param.missing_values_first,
                    )
                )
            elif col_param.order == api_pb2.ORDER_DESC:
                session_groups.sort(
                    key=_create_key_func(
                        extractor,
                        none_is_largest=col_param.missing_values_first,
                    ),
                    reverse=True,
                )
            else:
                raise error.HParamsError(
                    "Unknown col_param.order given: %s" % col_param
                )

    def _create_response(self, session_groups):
        """Builds the response for one page of 'session_groups'.

        The page is session_groups[start_index : start_index + slice_size],
        taken from the request; total_size is the unpaged count.
        """
        return api_pb2.ListSessionGroupsResponse(
            session_groups=session_groups[
                self._request.start_index : self._request.start_index
                + self._request.slice_size
            ],
            total_size=len(session_groups),
        )
466
+
467
+
468
+ def _create_key_func(extractor, none_is_largest):
469
+ """Returns a key_func to be used in list.sort().
470
+
471
+ Returns a key_func to be used in list.sort() that sorts session groups
472
+ by the value extracted by extractor. 'None' extracted values will either
473
+ be considered largest or smallest as specified by the "none_is_largest"
474
+ boolean parameter.
475
+
476
+ Args:
477
+ extractor: An extractor function that extract the key from the session
478
+ group.
479
+ none_is_largest: bool. If true treats 'None's as largest; otherwise
480
+ smallest.
481
+ """
482
+ if none_is_largest:
483
+
484
+ def key_func_none_is_largest(session_group):
485
+ value = extractor(session_group)
486
+ return (value is None, value)
487
+
488
+ return key_func_none_is_largest
489
+
490
+ def key_func_none_is_smallest(session_group):
491
+ value = extractor(session_group)
492
+ return (value is not None, value)
493
+
494
+ return key_func_none_is_smallest
495
+
496
+
497
+ # Extractors. An extractor is a function that extracts some property (a metric
498
+ # or a hyperparameter) from a SessionGroup instance.
499
+ def _create_extractors(col_params):
500
+ """Creates extractors to extract properties corresponding to 'col_params'.
501
+
502
+ Args:
503
+ col_params: List of ListSessionGroupsRequest.ColParam protobufs.
504
+ Returns:
505
+ A list of extractor functions. The ith element in the
506
+ returned list extracts the column corresponding to the ith element of
507
+ _request.col_params
508
+ """
509
+ result = []
510
+ for col_param in col_params:
511
+ result.append(_create_extractor(col_param))
512
+ return result
513
+
514
+
515
+ def _create_extractor(col_param):
516
+ if col_param.HasField("metric"):
517
+ return _create_metric_extractor(col_param.metric)
518
+ elif col_param.HasField("hparam"):
519
+ return _create_hparam_extractor(col_param.hparam)
520
+ else:
521
+ raise error.HParamsError(
522
+ 'Got ColParam with both "metric" and "hparam" fields unset: %s'
523
+ % col_param
524
+ )
525
+
526
+
527
+ def _create_metric_extractor(metric_name):
528
+ """Returns function that extracts a metric from a session group or a
529
+ session.
530
+
531
+ Args:
532
+ metric_name: tensorboard.hparams.MetricName protobuffer. Identifies the
533
+ metric to extract from the session group.
534
+ Returns:
535
+ A function that takes a tensorboard.hparams.SessionGroup or
536
+ tensorborad.hparams.Session protobuffer and returns the value of the metric
537
+ identified by 'metric_name' or None if the value doesn't exist.
538
+ """
539
+
540
+ def extractor_fn(session_or_group):
541
+ metric_value = _find_metric_value(session_or_group, metric_name)
542
+ return metric_value.value if metric_value else None
543
+
544
+ return extractor_fn
545
+
546
+
547
+ def _find_metric_value(session_or_group, metric_name):
548
+ """Returns the metric_value for a given metric in a session or session
549
+ group.
550
+
551
+ Args:
552
+ session_or_group: A Session protobuffer or SessionGroup protobuffer.
553
+ metric_name: A MetricName protobuffer. The metric to search for.
554
+ Returns:
555
+ A MetricValue protobuffer representing the value of the given metric or
556
+ None if no such metric was found in session_or_group.
557
+ """
558
+ # Note: We can speed this up by converting the metric_values field
559
+ # to a dictionary on initialization, to avoid a linear search here. We'll
560
+ # need to wrap the SessionGroup and Session protos in a python object for
561
+ # that.
562
+ for metric_value in session_or_group.metric_values:
563
+ if (
564
+ metric_value.name.tag == metric_name.tag
565
+ and metric_value.name.group == metric_name.group
566
+ ):
567
+ return metric_value
568
+
569
+
570
+ def _create_hparam_extractor(hparam_name):
571
+ """Returns an extractor function that extracts an hparam from a session
572
+ group.
573
+
574
+ Args:
575
+ hparam_name: str. Identies the hparam to extract from the session group.
576
+ Returns:
577
+ A function that takes a tensorboard.hparams.SessionGroup protobuffer and
578
+ returns the value, as a native Python object, of the hparam identified by
579
+ 'hparam_name'.
580
+ """
581
+
582
+ def extractor_fn(session_group):
583
+ if hparam_name in session_group.hparams:
584
+ return _value_to_python(session_group.hparams[hparam_name])
585
+ return None
586
+
587
+ return extractor_fn
588
+
589
+
590
+ # Filters. A filter is a boolean function that takes a session group and returns
591
+ # True if it should be included in the result. Currently, Filters are functions
592
+ # of a single column value extracted from the session group with a given
593
+ # extractor specified in the construction of the filter.
594
+ def _create_filters(col_params, extractors, *, include_hparam_filters=True):
595
+ """Creates filters for the given col_params.
596
+
597
+ Args:
598
+ col_params: List of ListSessionGroupsRequest.ColParam protobufs.
599
+ extractors: list of extractor functions of the same length as col_params.
600
+ Each element should extract the column described by the corresponding
601
+ element of col_params.
602
+ include_hparam_filters: bool that indicates whether hparam filters should
603
+ be generated. Defaults to True.
604
+ Returns:
605
+ A list of filter functions. Each corresponding to a single
606
+ col_params.filter oneof field of _request
607
+ """
608
+ result = []
609
+ for col_param, extractor in zip(col_params, extractors):
610
+ if not include_hparam_filters and col_param.hparam:
611
+ continue
612
+
613
+ a_filter = _create_filter(col_param, extractor)
614
+ if a_filter:
615
+ result.append(a_filter)
616
+ return result
617
+
618
+
619
+ def _create_filter(col_param, extractor):
620
+ """Creates a filter for the given col_param and extractor.
621
+
622
+ Args:
623
+ col_param: A tensorboard.hparams.ColParams object identifying the column
624
+ and describing the filter to apply.
625
+ extractor: A function that extract the column value identified by
626
+ 'col_param' from a tensorboard.hparams.SessionGroup protobuffer.
627
+ Returns:
628
+ A boolean function taking a tensorboard.hparams.SessionGroup protobuffer
629
+ returning True if the session group passes the filter described by
630
+ 'col_param'. If col_param does not specify a filter (i.e. any session
631
+ group passes) returns None.
632
+ """
633
+ include_missing_values = not col_param.exclude_missing_values
634
+ if col_param.HasField("filter_regexp"):
635
+ value_filter_fn = _create_regexp_filter(col_param.filter_regexp)
636
+ elif col_param.HasField("filter_interval"):
637
+ value_filter_fn = _create_interval_filter(col_param.filter_interval)
638
+ elif col_param.HasField("filter_discrete"):
639
+ value_filter_fn = _create_discrete_set_filter(col_param.filter_discrete)
640
+ elif include_missing_values:
641
+ # No 'filter' field and include_missing_values is True.
642
+ # Thus, the resulting filter always returns True, so to optimize for this
643
+ # common case we do not include it in the list of filters to check.
644
+ return None
645
+ else:
646
+ value_filter_fn = lambda _: True
647
+
648
+ def filter_fn(session_group):
649
+ value = extractor(session_group)
650
+ if value is None:
651
+ return include_missing_values
652
+ return value_filter_fn(value)
653
+
654
+ return filter_fn
655
+
656
+
657
+ def _create_regexp_filter(regex):
658
+ """Returns a boolean function that filters strings based on a regular exp.
659
+
660
+ Args:
661
+ regex: A string describing the regexp to use.
662
+ Returns:
663
+ A function taking a string and returns True if any of its substrings
664
+ matches regex.
665
+ """
666
+ # Warning: Note that python's regex library allows inputs that take
667
+ # exponential time. Time-limiting it is difficult. When we move to
668
+ # a true multi-tenant tensorboard server, the regexp implementation here
669
+ # would need to be replaced by something more secure.
670
+ compiled_regex = re.compile(regex)
671
+
672
+ def filter_fn(value):
673
+ if not isinstance(value, str):
674
+ raise error.HParamsError(
675
+ "Cannot use a regexp filter for a value of type %s. Value: %s"
676
+ % (type(value), value)
677
+ )
678
+ return re.search(compiled_regex, value) is not None
679
+
680
+ return filter_fn
681
+
682
+
683
+ def _create_interval_filter(interval):
684
+ """Returns a function that checkes whether a number belongs to an interval.
685
+
686
+ Args:
687
+ interval: A tensorboard.hparams.Interval protobuf describing the interval.
688
+ Returns:
689
+ A function taking a number (float or int) that returns True if the number
690
+ belongs to (the closed) 'interval'.
691
+ """
692
+
693
+ def filter_fn(value):
694
+ if not isinstance(value, (int, float)):
695
+ raise error.HParamsError(
696
+ "Cannot use an interval filter for a value of type: %s, Value: %s"
697
+ % (type(value), value)
698
+ )
699
+ return interval.min_value <= value and value <= interval.max_value
700
+
701
+ return filter_fn
702
+
703
+
704
+ def _create_discrete_set_filter(discrete_set):
705
+ """Returns a function that checks whether a value belongs to a set.
706
+
707
+ Args:
708
+ discrete_set: A list of objects representing the set.
709
+ Returns:
710
+ A function taking an object and returns True if its in the set. Membership
711
+ is tested using the Python 'in' operator (thus, equality of distinct
712
+ objects is computed using the '==' operator).
713
+ """
714
+
715
+ def filter_fn(value):
716
+ return value in discrete_set
717
+
718
+ return filter_fn
719
+
720
+
721
+ def _value_to_python(value):
722
+ """Converts a google.protobuf.Value to a native Python object."""
723
+
724
+ assert isinstance(value, struct_pb2.Value)
725
+ field = value.WhichOneof("kind")
726
+ if field == "number_value":
727
+ return value.number_value
728
+ elif field == "string_value":
729
+ return value.string_value
730
+ elif field == "bool_value":
731
+ return value.bool_value
732
+ else:
733
+ raise ValueError("Unknown struct_pb2.Value oneof field set: %s" % field)
734
+
735
+
736
+ @dataclasses.dataclass(frozen=True)
737
+ class _MetricIdentifier:
738
+ """An identifier for a metric.
739
+
740
+ As protobuffers are mutable we can't use MetricName directly as a dict's key.
741
+ Instead, we represent MetricName protocol buffer as an immutable dataclass.
742
+
743
+ Attributes:
744
+ group: Metric group corresponding to the dataset on which the model was
745
+ evaluated.
746
+ tag: String tag associated with the metric.
747
+ """
748
+
749
+ group: str
750
+ tag: str
751
+
752
+
753
+ class _MetricStats:
754
+ """A simple class to hold metric stats used in calculating metric averages.
755
+
756
+ Used in _set_avg_session_metrics(). See the comments in that function
757
+ for more details.
758
+
759
+ Attributes:
760
+ total: int. The sum of the metric measurements seen so far.
761
+ count: int. The number of largest-step measuremens seen so far.
762
+ total_step: int. The sum of the steps at which the measurements were taken
763
+ total_wall_time_secs: float. The sum of the wall_time_secs at
764
+ which the measurements were taken.
765
+ """
766
+
767
+ # We use slots here to catch typos in attributes earlier. Note that this makes
768
+ # this class incompatible with 'pickle'.
769
+ __slots__ = [
770
+ "total",
771
+ "count",
772
+ "total_step",
773
+ "total_wall_time_secs",
774
+ ]
775
+
776
+ def __init__(self):
777
+ self.total = 0
778
+ self.count = 0
779
+ self.total_step = 0
780
+ self.total_wall_time_secs = 0.0
781
+
782
+
783
+ def _set_avg_session_metrics(session_group):
784
+ """Sets the metrics for the group to be the average of its sessions.
785
+
786
+ The resulting session group metrics consist of the union of metrics across
787
+ the group's sessions. The value of each session group metric is the average
788
+ of that metric values across the sessions in the group. The 'step' and
789
+ 'wall_time_secs' fields of the resulting MetricValue field in the session
790
+ group are populated with the corresponding averages (truncated for 'step')
791
+ as well.
792
+
793
+ Args:
794
+ session_group: A SessionGroup protobuffer.
795
+ """
796
+ assert session_group.sessions, "SessionGroup cannot be empty."
797
+ # Algorithm: Iterate over all (session, metric) pairs and maintain a
798
+ # dict from _MetricIdentifier to _MetricStats objects.
799
+ # Then use the final dict state to compute the average for each metric.
800
+ metric_stats = collections.defaultdict(_MetricStats)
801
+ for session in session_group.sessions:
802
+ for metric_value in session.metric_values:
803
+ metric_name = _MetricIdentifier(
804
+ group=metric_value.name.group, tag=metric_value.name.tag
805
+ )
806
+ stats = metric_stats[metric_name]
807
+ stats.total += metric_value.value
808
+ stats.count += 1
809
+ stats.total_step += metric_value.training_step
810
+ stats.total_wall_time_secs += metric_value.wall_time_secs
811
+
812
+ del session_group.metric_values[:]
813
+ for metric_name, stats in metric_stats.items():
814
+ session_group.metric_values.add(
815
+ name=api_pb2.MetricName(
816
+ group=metric_name.group, tag=metric_name.tag
817
+ ),
818
+ value=float(stats.total) / float(stats.count),
819
+ training_step=stats.total_step // stats.count,
820
+ wall_time_secs=stats.total_wall_time_secs / stats.count,
821
+ )
822
+
823
+
824
+ @dataclasses.dataclass(frozen=True)
825
+ class _Measurement:
826
+ """Holds a session's metric value.
827
+
828
+ Attributes:
829
+ metric_value: Metric value of the session.
830
+ session_index: Index of the session in its group.
831
+ """
832
+
833
+ metric_value: Optional[api_pb2.MetricValue]
834
+ session_index: int
835
+
836
+
837
+ def _set_median_session_metrics(session_group, aggregation_metric):
838
+ """Sets the metrics for session_group to those of its "median session".
839
+
840
+ The median session is the session in session_group with the median value
841
+ of the metric given by 'aggregation_metric'. The median is taken over the
842
+ subset of sessions in the group whose 'aggregation_metric' was measured
843
+ at the largest training step among the sessions in the group.
844
+
845
+ Args:
846
+ session_group: A SessionGroup protobuffer.
847
+ aggregation_metric: A MetricName protobuffer.
848
+ """
849
+ measurements = sorted(
850
+ _measurements(session_group, aggregation_metric),
851
+ key=operator.attrgetter("metric_value.value"),
852
+ )
853
+ median_session = measurements[(len(measurements) - 1) // 2].session_index
854
+ del session_group.metric_values[:]
855
+ session_group.metric_values.MergeFrom(
856
+ session_group.sessions[median_session].metric_values
857
+ )
858
+
859
+
860
+ def _set_extremum_session_metrics(
861
+ session_group, aggregation_metric, extremum_fn
862
+ ):
863
+ """Sets the metrics for session_group to those of its "extremum session".
864
+
865
+ The extremum session is the session in session_group with the extremum value
866
+ of the metric given by 'aggregation_metric'. The extremum is taken over the
867
+ subset of sessions in the group whose 'aggregation_metric' was measured
868
+ at the largest training step among the sessions in the group.
869
+
870
+ Args:
871
+ session_group: A SessionGroup protobuffer.
872
+ aggregation_metric: A MetricName protobuffer.
873
+ extremum_fn: callable. Must be either 'min' or 'max'. Determines the type of
874
+ extremum to compute.
875
+ """
876
+ measurements = _measurements(session_group, aggregation_metric)
877
+ ext_session = extremum_fn(
878
+ measurements, key=operator.attrgetter("metric_value.value")
879
+ ).session_index
880
+ del session_group.metric_values[:]
881
+ session_group.metric_values.MergeFrom(
882
+ session_group.sessions[ext_session].metric_values
883
+ )
884
+
885
+
886
+ def _measurements(session_group, metric_name):
887
+ """A generator for the values of the metric across the sessions in the
888
+ group.
889
+
890
+ Args:
891
+ session_group: A SessionGroup protobuffer.
892
+ metric_name: A MetricName protobuffer.
893
+ Yields:
894
+ The next metric value wrapped in a _Measurement instance.
895
+ """
896
+ for session_index, session in enumerate(session_group.sessions):
897
+ metric_value = _find_metric_value(session, metric_name)
898
+ if not metric_value:
899
+ continue
900
+ yield _Measurement(metric_value, session_index)
901
+
902
+
903
+ def _build_data_provider_filters(col_params):
904
+ """Builds HyperparameterFilters from ColParams."""
905
+ filters = []
906
+ for col_param in col_params:
907
+ if not col_param.hparam:
908
+ # We do not pass metric filters to the DataProvider as it does not
909
+ # have the metric data for filtering.
910
+ continue
911
+
912
+ fltr = _build_data_provider_filter(col_param)
913
+ if fltr is None:
914
+ continue
915
+ filters.append(fltr)
916
+ return filters
917
+
918
+
919
+ def _build_data_provider_filter(col_param):
920
+ """Builds HyperparameterFilter from ColParam.
921
+
922
+ Args:
923
+ col_param: ColParam that possibly contains filter information.
924
+
925
+ Returns:
926
+ None if col_param does not specify filter information.
927
+ """
928
+ if col_param.HasField("filter_regexp"):
929
+ filter_type = provider.HyperparameterFilterType.REGEX
930
+ fltr = col_param.filter_regexp
931
+ elif col_param.HasField("filter_interval"):
932
+ filter_type = provider.HyperparameterFilterType.INTERVAL
933
+ fltr = (
934
+ col_param.filter_interval.min_value,
935
+ col_param.filter_interval.max_value,
936
+ )
937
+ elif col_param.HasField("filter_discrete"):
938
+ filter_type = provider.HyperparameterFilterType.DISCRETE
939
+ fltr = [_value_to_python(b) for b in col_param.filter_discrete.values]
940
+ else:
941
+ return None
942
+
943
+ return provider.HyperparameterFilter(
944
+ hyperparameter_name=col_param.hparam,
945
+ filter_type=filter_type,
946
+ filter=fltr,
947
+ )
948
+
949
+
950
+ def _build_data_provider_sort(col_params):
951
+ """Builds HyperparameterSorts from ColParams."""
952
+ sort = []
953
+ for col_param in col_params:
954
+ sort_item = _build_data_provider_sort_item(col_param)
955
+ if sort_item is None:
956
+ continue
957
+ sort.append(sort_item)
958
+ return sort
959
+
960
+
961
+ def _build_data_provider_sort_item(col_param):
962
+ """Builds HyperparameterSort from ColParam.
963
+
964
+ Args:
965
+ col_param: ColParam that possibly contains sort information.
966
+
967
+ Returns:
968
+ None if col_param does not specify sort information.
969
+ """
970
+ if col_param.order == api_pb2.ORDER_UNSPECIFIED:
971
+ return None
972
+
973
+ sort_direction = (
974
+ provider.HyperparameterSortDirection.ASCENDING
975
+ if col_param.order == api_pb2.ORDER_ASC
976
+ else provider.HyperparameterSortDirection.DESCENDING
977
+ )
978
+ return provider.HyperparameterSort(
979
+ hyperparameter_name=col_param.hparam,
980
+ sort_direction=sort_direction,
981
+ )
982
+
983
+
984
+ def _specifies_include(col_params):
985
+ """Determines whether any `ColParam` contains the `include_in_result` field.
986
+
987
+ In the case where none of the col_params contains the field, we should assume
988
+ that all fields should be included in the response.
989
+ """
990
+ return any(
991
+ col_param.HasField("include_in_result") for col_param in col_params
992
+ )
993
+
994
+
995
+ def _get_hparams_to_include(col_params):
996
+ """Generates the list of hparams to include in the response.
997
+
998
+ The determination is based on the `include_in_result` field in ColParam. If
999
+ a ColParam either has `include_in_result: True` or does not specify the
1000
+ field at all, then it should be included in the result.
1001
+
1002
+ Args:
1003
+ col_params: A collection of `ColParams` protos.
1004
+
1005
+ Returns:
1006
+ A list of names of hyperparameters to include in the response.
1007
+ """
1008
+ hparams_to_include = []
1009
+ for col_param in col_params:
1010
+ if (
1011
+ col_param.HasField("include_in_result")
1012
+ and not col_param.include_in_result
1013
+ ):
1014
+ # Explicitly set to exclude this hparam.
1015
+ continue
1016
+ if col_param.hparam:
1017
+ hparams_to_include.append(col_param.hparam)
1018
+ return hparams_to_include
1019
+
1020
+
1021
+ def _reduce_to_hparams_to_include(session_groups, col_params):
1022
+ """Removes hparams from session_groups that should not be included.
1023
+
1024
+ Args:
1025
+ session_groups: A collection of `SessionGroup` protos, which will be
1026
+ modified in place.
1027
+ col_params: A collection of `ColParams` protos.
1028
+ """
1029
+ hparams_to_include = _get_hparams_to_include(col_params)
1030
+
1031
+ for session_group in session_groups:
1032
+ new_hparams = {
1033
+ hparam: value
1034
+ for (hparam, value) in session_group.hparams.items()
1035
+ if hparam in hparams_to_include
1036
+ }
1037
+
1038
+ session_group.ClearField("hparams")
1039
+ for hparam, value in new_hparams.items():
1040
+ session_group.hparams[hparam].CopyFrom(value)