tensorbored-2.21.0rc1769983804-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensorbored/__init__.py +112 -0
- tensorbored/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/__init__.py +125 -0
- tensorbored/_vendor/bleach/_vendor/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/__init__.py +35 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_ihatexml.py +289 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_inputstream.py +918 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_tokenizer.py +1735 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/__init__.py +5 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/_base.py +40 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_trie/py.py +67 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/_utils.py +159 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/constants.py +2946 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/__init__.py +0 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/alphabeticalattributes.py +29 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/base.py +12 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/inject_meta_charset.py +73 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/lint.py +93 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/optionaltags.py +207 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/sanitizer.py +916 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/filters/whitespace.py +38 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/html5parser.py +2795 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/serializer.py +409 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/__init__.py +30 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/genshi.py +54 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treeadapters/sax.py +50 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/__init__.py +88 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/base.py +417 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/dom.py +239 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree.py +343 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treebuilders/etree_lxml.py +392 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/__init__.py +154 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/base.py +252 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/dom.py +43 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree.py +131 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/etree_lxml.py +215 -0
- tensorbored/_vendor/bleach/_vendor/html5lib/treewalkers/genshi.py +69 -0
- tensorbored/_vendor/bleach/_vendor/parse.py +1078 -0
- tensorbored/_vendor/bleach/callbacks.py +32 -0
- tensorbored/_vendor/bleach/html5lib_shim.py +757 -0
- tensorbored/_vendor/bleach/linkifier.py +633 -0
- tensorbored/_vendor/bleach/parse_shim.py +1 -0
- tensorbored/_vendor/bleach/sanitizer.py +638 -0
- tensorbored/_vendor/bleach/six_shim.py +19 -0
- tensorbored/_vendor/webencodings/__init__.py +342 -0
- tensorbored/_vendor/webencodings/labels.py +231 -0
- tensorbored/_vendor/webencodings/mklabels.py +59 -0
- tensorbored/_vendor/webencodings/x_user_defined.py +325 -0
- tensorbored/assets.py +36 -0
- tensorbored/auth.py +102 -0
- tensorbored/backend/__init__.py +0 -0
- tensorbored/backend/application.py +604 -0
- tensorbored/backend/auth_context_middleware.py +38 -0
- tensorbored/backend/client_feature_flags.py +113 -0
- tensorbored/backend/empty_path_redirect.py +46 -0
- tensorbored/backend/event_processing/__init__.py +0 -0
- tensorbored/backend/event_processing/data_ingester.py +276 -0
- tensorbored/backend/event_processing/data_provider.py +535 -0
- tensorbored/backend/event_processing/directory_loader.py +142 -0
- tensorbored/backend/event_processing/directory_watcher.py +272 -0
- tensorbored/backend/event_processing/event_accumulator.py +950 -0
- tensorbored/backend/event_processing/event_file_inspector.py +463 -0
- tensorbored/backend/event_processing/event_file_loader.py +292 -0
- tensorbored/backend/event_processing/event_multiplexer.py +521 -0
- tensorbored/backend/event_processing/event_util.py +68 -0
- tensorbored/backend/event_processing/io_wrapper.py +223 -0
- tensorbored/backend/event_processing/plugin_asset_util.py +104 -0
- tensorbored/backend/event_processing/plugin_event_accumulator.py +721 -0
- tensorbored/backend/event_processing/plugin_event_multiplexer.py +522 -0
- tensorbored/backend/event_processing/reservoir.py +266 -0
- tensorbored/backend/event_processing/tag_types.py +29 -0
- tensorbored/backend/experiment_id.py +71 -0
- tensorbored/backend/experimental_plugin.py +51 -0
- tensorbored/backend/http_util.py +263 -0
- tensorbored/backend/json_util.py +70 -0
- tensorbored/backend/path_prefix.py +67 -0
- tensorbored/backend/process_graph.py +74 -0
- tensorbored/backend/security_validator.py +202 -0
- tensorbored/compat/__init__.py +69 -0
- tensorbored/compat/proto/__init__.py +0 -0
- tensorbored/compat/proto/allocation_description_pb2.py +35 -0
- tensorbored/compat/proto/api_def_pb2.py +82 -0
- tensorbored/compat/proto/attr_value_pb2.py +80 -0
- tensorbored/compat/proto/cluster_pb2.py +58 -0
- tensorbored/compat/proto/config_pb2.py +271 -0
- tensorbored/compat/proto/coordination_config_pb2.py +45 -0
- tensorbored/compat/proto/cost_graph_pb2.py +87 -0
- tensorbored/compat/proto/cpp_shape_inference_pb2.py +70 -0
- tensorbored/compat/proto/debug_pb2.py +65 -0
- tensorbored/compat/proto/event_pb2.py +149 -0
- tensorbored/compat/proto/full_type_pb2.py +74 -0
- tensorbored/compat/proto/function_pb2.py +157 -0
- tensorbored/compat/proto/graph_debug_info_pb2.py +111 -0
- tensorbored/compat/proto/graph_pb2.py +41 -0
- tensorbored/compat/proto/histogram_pb2.py +39 -0
- tensorbored/compat/proto/meta_graph_pb2.py +254 -0
- tensorbored/compat/proto/node_def_pb2.py +61 -0
- tensorbored/compat/proto/op_def_pb2.py +81 -0
- tensorbored/compat/proto/resource_handle_pb2.py +48 -0
- tensorbored/compat/proto/rewriter_config_pb2.py +93 -0
- tensorbored/compat/proto/rpc_options_pb2.py +35 -0
- tensorbored/compat/proto/saved_object_graph_pb2.py +193 -0
- tensorbored/compat/proto/saver_pb2.py +38 -0
- tensorbored/compat/proto/step_stats_pb2.py +116 -0
- tensorbored/compat/proto/struct_pb2.py +144 -0
- tensorbored/compat/proto/summary_pb2.py +111 -0
- tensorbored/compat/proto/tensor_description_pb2.py +38 -0
- tensorbored/compat/proto/tensor_pb2.py +68 -0
- tensorbored/compat/proto/tensor_shape_pb2.py +46 -0
- tensorbored/compat/proto/tfprof_log_pb2.py +307 -0
- tensorbored/compat/proto/trackable_object_graph_pb2.py +90 -0
- tensorbored/compat/proto/types_pb2.py +105 -0
- tensorbored/compat/proto/variable_pb2.py +62 -0
- tensorbored/compat/proto/verifier_config_pb2.py +38 -0
- tensorbored/compat/proto/versions_pb2.py +35 -0
- tensorbored/compat/tensorflow_stub/__init__.py +38 -0
- tensorbored/compat/tensorflow_stub/app.py +124 -0
- tensorbored/compat/tensorflow_stub/compat/__init__.py +131 -0
- tensorbored/compat/tensorflow_stub/compat/v1/__init__.py +20 -0
- tensorbored/compat/tensorflow_stub/dtypes.py +692 -0
- tensorbored/compat/tensorflow_stub/error_codes.py +169 -0
- tensorbored/compat/tensorflow_stub/errors.py +507 -0
- tensorbored/compat/tensorflow_stub/flags.py +124 -0
- tensorbored/compat/tensorflow_stub/io/__init__.py +17 -0
- tensorbored/compat/tensorflow_stub/io/gfile.py +1011 -0
- tensorbored/compat/tensorflow_stub/pywrap_tensorflow.py +285 -0
- tensorbored/compat/tensorflow_stub/tensor_shape.py +1035 -0
- tensorbored/context.py +129 -0
- tensorbored/data/__init__.py +0 -0
- tensorbored/data/grpc_provider.py +365 -0
- tensorbored/data/ingester.py +46 -0
- tensorbored/data/proto/__init__.py +0 -0
- tensorbored/data/proto/data_provider_pb2.py +517 -0
- tensorbored/data/proto/data_provider_pb2_grpc.py +374 -0
- tensorbored/data/provider.py +1365 -0
- tensorbored/data/server_ingester.py +301 -0
- tensorbored/data_compat.py +159 -0
- tensorbored/dataclass_compat.py +224 -0
- tensorbored/default.py +124 -0
- tensorbored/errors.py +130 -0
- tensorbored/lazy.py +99 -0
- tensorbored/main.py +48 -0
- tensorbored/main_lib.py +62 -0
- tensorbored/manager.py +487 -0
- tensorbored/notebook.py +441 -0
- tensorbored/plugin_util.py +266 -0
- tensorbored/plugins/__init__.py +0 -0
- tensorbored/plugins/audio/__init__.py +0 -0
- tensorbored/plugins/audio/audio_plugin.py +229 -0
- tensorbored/plugins/audio/metadata.py +69 -0
- tensorbored/plugins/audio/plugin_data_pb2.py +37 -0
- tensorbored/plugins/audio/summary.py +230 -0
- tensorbored/plugins/audio/summary_v2.py +124 -0
- tensorbored/plugins/base_plugin.py +367 -0
- tensorbored/plugins/core/__init__.py +0 -0
- tensorbored/plugins/core/core_plugin.py +981 -0
- tensorbored/plugins/custom_scalar/__init__.py +0 -0
- tensorbored/plugins/custom_scalar/custom_scalars_plugin.py +320 -0
- tensorbored/plugins/custom_scalar/layout_pb2.py +85 -0
- tensorbored/plugins/custom_scalar/metadata.py +35 -0
- tensorbored/plugins/custom_scalar/summary.py +79 -0
- tensorbored/plugins/debugger_v2/__init__.py +0 -0
- tensorbored/plugins/debugger_v2/debug_data_multiplexer.py +631 -0
- tensorbored/plugins/debugger_v2/debug_data_provider.py +634 -0
- tensorbored/plugins/debugger_v2/debugger_v2_plugin.py +504 -0
- tensorbored/plugins/distribution/__init__.py +0 -0
- tensorbored/plugins/distribution/compressor.py +158 -0
- tensorbored/plugins/distribution/distributions_plugin.py +116 -0
- tensorbored/plugins/distribution/metadata.py +19 -0
- tensorbored/plugins/graph/__init__.py +0 -0
- tensorbored/plugins/graph/graph_util.py +129 -0
- tensorbored/plugins/graph/graphs_plugin.py +336 -0
- tensorbored/plugins/graph/keras_util.py +328 -0
- tensorbored/plugins/graph/metadata.py +42 -0
- tensorbored/plugins/histogram/__init__.py +0 -0
- tensorbored/plugins/histogram/histograms_plugin.py +144 -0
- tensorbored/plugins/histogram/metadata.py +63 -0
- tensorbored/plugins/histogram/plugin_data_pb2.py +34 -0
- tensorbored/plugins/histogram/summary.py +234 -0
- tensorbored/plugins/histogram/summary_v2.py +292 -0
- tensorbored/plugins/hparams/__init__.py +14 -0
- tensorbored/plugins/hparams/_keras.py +93 -0
- tensorbored/plugins/hparams/api.py +130 -0
- tensorbored/plugins/hparams/api_pb2.py +208 -0
- tensorbored/plugins/hparams/backend_context.py +606 -0
- tensorbored/plugins/hparams/download_data.py +158 -0
- tensorbored/plugins/hparams/error.py +26 -0
- tensorbored/plugins/hparams/get_experiment.py +71 -0
- tensorbored/plugins/hparams/hparams_plugin.py +206 -0
- tensorbored/plugins/hparams/hparams_util_pb2.py +69 -0
- tensorbored/plugins/hparams/json_format_compat.py +38 -0
- tensorbored/plugins/hparams/list_metric_evals.py +57 -0
- tensorbored/plugins/hparams/list_session_groups.py +1040 -0
- tensorbored/plugins/hparams/metadata.py +125 -0
- tensorbored/plugins/hparams/metrics.py +41 -0
- tensorbored/plugins/hparams/plugin_data_pb2.py +69 -0
- tensorbored/plugins/hparams/summary.py +205 -0
- tensorbored/plugins/hparams/summary_v2.py +597 -0
- tensorbored/plugins/image/__init__.py +0 -0
- tensorbored/plugins/image/images_plugin.py +232 -0
- tensorbored/plugins/image/metadata.py +65 -0
- tensorbored/plugins/image/plugin_data_pb2.py +34 -0
- tensorbored/plugins/image/summary.py +159 -0
- tensorbored/plugins/image/summary_v2.py +130 -0
- tensorbored/plugins/mesh/__init__.py +14 -0
- tensorbored/plugins/mesh/mesh_plugin.py +292 -0
- tensorbored/plugins/mesh/metadata.py +152 -0
- tensorbored/plugins/mesh/plugin_data_pb2.py +37 -0
- tensorbored/plugins/mesh/summary.py +251 -0
- tensorbored/plugins/mesh/summary_v2.py +214 -0
- tensorbored/plugins/metrics/__init__.py +0 -0
- tensorbored/plugins/metrics/metadata.py +17 -0
- tensorbored/plugins/metrics/metrics_plugin.py +623 -0
- tensorbored/plugins/pr_curve/__init__.py +0 -0
- tensorbored/plugins/pr_curve/metadata.py +75 -0
- tensorbored/plugins/pr_curve/plugin_data_pb2.py +34 -0
- tensorbored/plugins/pr_curve/pr_curves_plugin.py +241 -0
- tensorbored/plugins/pr_curve/summary.py +574 -0
- tensorbored/plugins/profile_redirect/__init__.py +0 -0
- tensorbored/plugins/profile_redirect/profile_redirect_plugin.py +49 -0
- tensorbored/plugins/projector/__init__.py +67 -0
- tensorbored/plugins/projector/metadata.py +26 -0
- tensorbored/plugins/projector/projector_config_pb2.py +54 -0
- tensorbored/plugins/projector/projector_plugin.py +795 -0
- tensorbored/plugins/projector/tf_projector_plugin/index.js +32 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.html +524 -0
- tensorbored/plugins/projector/tf_projector_plugin/projector_binary.js +15536 -0
- tensorbored/plugins/scalar/__init__.py +0 -0
- tensorbored/plugins/scalar/metadata.py +60 -0
- tensorbored/plugins/scalar/plugin_data_pb2.py +34 -0
- tensorbored/plugins/scalar/scalars_plugin.py +181 -0
- tensorbored/plugins/scalar/summary.py +109 -0
- tensorbored/plugins/scalar/summary_v2.py +124 -0
- tensorbored/plugins/text/__init__.py +0 -0
- tensorbored/plugins/text/metadata.py +62 -0
- tensorbored/plugins/text/plugin_data_pb2.py +34 -0
- tensorbored/plugins/text/summary.py +114 -0
- tensorbored/plugins/text/summary_v2.py +124 -0
- tensorbored/plugins/text/text_plugin.py +288 -0
- tensorbored/plugins/wit_redirect/__init__.py +0 -0
- tensorbored/plugins/wit_redirect/wit_redirect_plugin.py +49 -0
- tensorbored/program.py +910 -0
- tensorbored/summary/__init__.py +35 -0
- tensorbored/summary/_output.py +124 -0
- tensorbored/summary/_tf/__init__.py +14 -0
- tensorbored/summary/_tf/summary/__init__.py +178 -0
- tensorbored/summary/_writer.py +105 -0
- tensorbored/summary/v1.py +51 -0
- tensorbored/summary/v2.py +25 -0
- tensorbored/summary/writer/__init__.py +13 -0
- tensorbored/summary/writer/event_file_writer.py +291 -0
- tensorbored/summary/writer/record_writer.py +50 -0
- tensorbored/util/__init__.py +0 -0
- tensorbored/util/encoder.py +116 -0
- tensorbored/util/grpc_util.py +311 -0
- tensorbored/util/img_mime_type_detector.py +40 -0
- tensorbored/util/io_util.py +20 -0
- tensorbored/util/lazy_tensor_creator.py +110 -0
- tensorbored/util/op_evaluator.py +104 -0
- tensorbored/util/platform_util.py +20 -0
- tensorbored/util/tb_logging.py +24 -0
- tensorbored/util/tensor_util.py +617 -0
- tensorbored/util/timing.py +122 -0
- tensorbored/version.py +21 -0
- tensorbored/webfiles.zip +0 -0
- tensorbored-2.21.0rc1769983804.dist-info/METADATA +49 -0
- tensorbored-2.21.0rc1769983804.dist-info/RECORD +271 -0
- tensorbored-2.21.0rc1769983804.dist-info/WHEEL +5 -0
- tensorbored-2.21.0rc1769983804.dist-info/entry_points.txt +6 -0
- tensorbored-2.21.0rc1769983804.dist-info/licenses/LICENSE +739 -0
- tensorbored-2.21.0rc1769983804.dist-info/top_level.txt +1 -0
tensorbored/plugins/hparams/list_session_groups.py

@@ -0,0 +1,1040 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes and functions for handling the ListSessionGroups API call."""
+
+import collections
+import dataclasses
+import operator
+import re
+from typing import Optional
+
+from google.protobuf import struct_pb2
+
+from tensorbored.data import provider
+from tensorbored.plugins.hparams import api_pb2
+from tensorbored.plugins.hparams import backend_context as backend_context_lib
+from tensorbored.plugins.hparams import error
+from tensorbored.plugins.hparams import json_format_compat
+from tensorbored.plugins.hparams import metadata
+from tensorbored.plugins.hparams import metrics
+from tensorbored.plugins.hparams import plugin_data_pb2
+
+
+class Handler:
+    """Handles a ListSessionGroups request."""
+
+    def __init__(
+        self, request_context, backend_context, experiment_id, request
+    ):
+        """Constructor.
+
+        Args:
+          request_context: A tensorboard.context.RequestContext.
+          backend_context: A backend_context.Context instance.
+          experiment_id: A string, as from `plugin_util.experiment_id`.
+          request: A ListSessionGroupsRequest protobuf.
+        """
+        self._request_context = request_context
+        self._backend_context = backend_context
+        self._experiment_id = experiment_id
+        self._request = request
+        self._include_metrics = (
+            # Metrics are included by default if include_metrics is not
+            # specified in the request.
+            not self._request.HasField("include_metrics")
+            or self._request.include_metrics
+        )
+
+    def run(self):
+        """Handles the request specified on construction.
+
+        This operation first attempts to construct SessionGroup information
+        from hparam tags metadata.EXPERIMENT_TAG and
+        metadata.SESSION_START_INFO.
+
+        If no such tags are found, then will build SessionGroup information
+        using the results from DataProvider.read_hyperparameters().
+
+        Returns:
+          A ListSessionGroupsResponse object.
+        """
+
+        session_groups_from_tags = self._session_groups_from_tags()
+        if session_groups_from_tags:
+            return self._create_response(session_groups_from_tags)
+
+        session_groups_from_data_provider = (
+            self._session_groups_from_data_provider()
+        )
+        if session_groups_from_data_provider:
+            return self._create_response(session_groups_from_data_provider)
+
+        return api_pb2.ListSessionGroupsResponse(
+            session_groups=[], total_size=0
+        )
+
+    def _session_groups_from_tags(self):
+        """Constructs lists of SessionGroups based on hparam tag metadata."""
+        # Query for all Hparams summary metadata one time to minimize calls to
+        # the underlying DataProvider.
+        hparams_run_to_tag_to_content = self._backend_context.hparams_metadata(
+            self._request_context, self._experiment_id
+        )
+        # Construct the experiment one time since an context.experiment() call
+        # may search through all the runs.
+        experiment = self._backend_context.experiment_from_metadata(
+            self._request_context,
+            self._experiment_id,
+            self._include_metrics,
+            hparams_run_to_tag_to_content,
+            # Don't pass any information from the DataProvider since we are only
+            # examining session groups based on tag metadata
+            provider.ListHyperparametersResult(
+                hyperparameters=[], session_groups=[]
+            ),
+        )
+        extractors = _create_extractors(self._request.col_params)
+        filters = _create_filters(self._request.col_params, extractors)
+
+        session_groups = self._build_session_groups(
+            hparams_run_to_tag_to_content, experiment.metric_infos
+        )
+        session_groups = self._filter(session_groups, filters)
+        self._sort(session_groups, extractors)
+
+        if _specifies_include(self._request.col_params):
+            _reduce_to_hparams_to_include(
+                session_groups, self._request.col_params
+            )
+
+        return session_groups
+
+    def _session_groups_from_data_provider(self):
+        """Constructs lists of SessionGroups based on DataProvider results."""
+        filters = _build_data_provider_filters(self._request.col_params)
+        sort = _build_data_provider_sort(self._request.col_params)
+        hparams_to_include = (
+            _get_hparams_to_include(self._request.col_params)
+            if _specifies_include(self._request.col_params)
+            else None
+        )
+        response = self._backend_context.session_groups_from_data_provider(
+            self._request_context,
+            self._experiment_id,
+            filters,
+            sort,
+            hparams_to_include,
+        )
+
+        metric_infos = (
+            self._backend_context.compute_metric_infos_from_data_provider_session_groups(
+                self._request_context, self._experiment_id, response
+            )
+            if self._include_metrics
+            else []
+        )
+
+        all_metric_evals = (
+            self._backend_context.read_last_scalars(
+                self._request_context,
+                self._experiment_id,
+                run_tag_filter=None,
+            )
+            if self._include_metrics
+            else {}
+        )
+
+        session_groups = []
+        for provider_group in response:
+            sessions = []
+            for session in provider_group.sessions:
+                session_name = (
+                    backend_context_lib.generate_data_provider_session_name(
+                        session
+                    )
+                )
+                sessions.append(
+                    self._build_session(
+                        metric_infos,
+                        session_name,
+                        plugin_data_pb2.SessionStartInfo(),
+                        plugin_data_pb2.SessionEndInfo(),
+                        all_metric_evals,
+                    )
+                )
+
+            name = backend_context_lib.generate_data_provider_session_name(
+                provider_group.root
+            )
+            if not name:
+                name = self._experiment_id
+            session_group = api_pb2.SessionGroup(
+                name=name,
+                sessions=sessions,
+            )
+
+            for provider_hparam in provider_group.hyperparameter_values:
+                hparam = session_group.hparams[
+                    provider_hparam.hyperparameter_name
+                ]
+                if (
+                    provider_hparam.domain_type
+                    == provider.HyperparameterDomainType.DISCRETE_STRING
+                ):
+                    hparam.string_value = provider_hparam.value
+                elif provider_hparam.domain_type in [
+                    provider.HyperparameterDomainType.DISCRETE_FLOAT,
+                    provider.HyperparameterDomainType.INTERVAL,
+                ]:
+                    hparam.number_value = provider_hparam.value
+                elif (
+                    provider_hparam.domain_type
+                    == provider.HyperparameterDomainType.DISCRETE_BOOL
+                ):
+                    hparam.bool_value = provider_hparam.value
+
+            session_groups.append(session_group)
+
+        # Compute the session group's aggregated metrics for each group.
+        for group in session_groups:
+            if group.sessions:
+                self._aggregate_metrics(group)
+
+        extractors = _create_extractors(self._request.col_params)
+        filters = _create_filters(
+            self._request.col_params,
+            extractors,
+            # We assume the DataProvider will apply hparam filters and we do not
+            # attempt to reapply them.
+            include_hparam_filters=False,
+        )
+        session_groups = self._filter(session_groups, filters)
+        return session_groups
+
+    def _build_session_groups(
+        self, hparams_run_to_tag_to_content, metric_infos
+    ):
+        """Returns a list of SessionGroups protobuffers from the summary
+        data."""
+
+        # Algorithm: We keep a dict 'groups_by_name' mapping a SessionGroup name
+        # (str) to a SessionGroup protobuffer. We traverse the runs associated with
+        # the plugin--each representing a single session. We form a Session
+        # protobuffer from each run and add it to the relevant SessionGroup object
+        # in the 'groups_by_name' dict. We create the SessionGroup object, if this
+        # is the first session of that group we encounter.
+        groups_by_name = {}
+        # The TensorBoard runs with session start info are the
+        # "sessions", which are not necessarily the runs that actually
+        # contain metrics (may be in subdirectories).
+        session_names = [
+            run
+            for (run, tags) in hparams_run_to_tag_to_content.items()
+            if metadata.SESSION_START_INFO_TAG in tags
+        ]
+        metric_runs = set()
+        metric_tags = set()
+        for session_name in session_names:
+            for metric in metric_infos:
+                metric_name = metric.name
+                run, tag = metrics.run_tag_from_session_and_metric(
+                    session_name, metric_name
+                )
+                metric_runs.add(run)
+                metric_tags.add(tag)
+        all_metric_evals = (
+            self._backend_context.read_last_scalars(
+                self._request_context,
+                self._experiment_id,
+                run_tag_filter=provider.RunTagFilter(
+                    runs=metric_runs, tags=metric_tags
+                ),
+            )
+            if self._include_metrics
+            else {}
+        )
+        for (
+            session_name,
+            tag_to_content,
+        ) in hparams_run_to_tag_to_content.items():
+            if metadata.SESSION_START_INFO_TAG not in tag_to_content:
+                continue
+            start_info = metadata.parse_session_start_info_plugin_data(
+                tag_to_content[metadata.SESSION_START_INFO_TAG]
+            )
+            end_info = None
+            if metadata.SESSION_END_INFO_TAG in tag_to_content:
+                end_info = metadata.parse_session_end_info_plugin_data(
+                    tag_to_content[metadata.SESSION_END_INFO_TAG]
+                )
+            session = self._build_session(
+                metric_infos,
+                session_name,
+                start_info,
+                end_info,
+                all_metric_evals,
+            )
+            if session.status in self._request.allowed_statuses:
+                self._add_session(session, start_info, groups_by_name)
+
+        # Compute the session group's aggregated metrics for each group.
+        groups = groups_by_name.values()
+        for group in groups:
+            # We sort the sessions in a group so that the order is deterministic.
+            group.sessions.sort(key=operator.attrgetter("name"))
+            self._aggregate_metrics(group)
+        return groups
+
+    def _add_session(self, session, start_info, groups_by_name):
+        """Adds a new Session protobuffer to the 'groups_by_name' dictionary.
+
+        Called by _build_session_groups when we encounter a new session. Creates
+        the Session protobuffer and adds it to the relevant group in the
+        'groups_by_name' dict. Creates the session group if this is the first time
+        we encounter it.
+
+        Args:
+          session: api_pb2.Session. The session to add.
+          start_info: The SessionStartInfo protobuffer associated with the session.
+          groups_by_name: A str to SessionGroup protobuffer dict. Representing the
+            session groups and sessions found so far.
+        """
+        # If the group_name is empty, this session's group contains only
+        # this session. Use the session name for the group name since session
+        # names are unique.
+        group_name = start_info.group_name or session.name
+        if group_name in groups_by_name:
+            groups_by_name[group_name].sessions.extend([session])
+        else:
+            # Create the group and add the session as the first one.
+            group = api_pb2.SessionGroup(
+                name=group_name,
+                sessions=[session],
+                monitor_url=start_info.monitor_url,
+            )
+            # Copy hparams from the first session (all sessions should have the same
+            # hyperparameter values) into result.
+            # There doesn't seem to be a way to initialize a protobuffer map in the
+            # constructor.
+            for key, value in start_info.hparams.items():
+                if not json_format_compat.is_serializable_value(value):
+                    # NaN number_value cannot be serialized by higher level layers
+                    # that are using json_format.MessageToJson(). To workaround
+                    # the issue we do not copy them to the session group and
+                    # effectively treat them as "unset".
+                    continue
+
+                group.hparams[key].CopyFrom(value)
+            groups_by_name[group_name] = group
+
+    def _build_session(
+        self, metric_infos, name, start_info, end_info, all_metric_evals
+    ):
+        """Builds a session object."""
+
+        assert start_info is not None
+        result = api_pb2.Session(
+            name=name,
+            start_time_secs=start_info.start_time_secs,
+            model_uri=start_info.model_uri,
+            metric_values=self._build_session_metric_values(
+                metric_infos, name, all_metric_evals
+            ),
+            monitor_url=start_info.monitor_url,
+        )
+        if end_info is not None:
+            result.status = end_info.status
+            result.end_time_secs = end_info.end_time_secs
+        return result
+
+    def _build_session_metric_values(
+        self, metric_infos, session_name, all_metric_evals
+    ):
+        """Builds the session metric values."""
+
+        # result is a list of api_pb2.MetricValue instances.
+        result = []
+        for metric_info in metric_infos:
+            metric_name = metric_info.name
+            run, tag = metrics.run_tag_from_session_and_metric(
+                session_name, metric_name
+            )
+            datum = all_metric_evals.get(run, {}).get(tag)
+            if not datum:
+                # It's ok if we don't find the metric in the session.
+                # We skip it here. For filtering and sorting purposes its value is None.
+                continue
+            result.append(
+                api_pb2.MetricValue(
+                    name=metric_name,
+                    wall_time_secs=datum.wall_time,
+                    training_step=datum.step,
+                    value=datum.value,
+                )
+            )
+        return result
+
+    def _aggregate_metrics(self, session_group):
+        """Sets the metrics of the group based on aggregation_type."""
+
+        if (
+            self._request.aggregation_type == api_pb2.AGGREGATION_AVG
+            or self._request.aggregation_type == api_pb2.AGGREGATION_UNSET
+        ):
+            _set_avg_session_metrics(session_group)
+        elif self._request.aggregation_type == api_pb2.AGGREGATION_MEDIAN:
+            _set_median_session_metrics(
+                session_group, self._request.aggregation_metric
+            )
+        elif self._request.aggregation_type == api_pb2.AGGREGATION_MIN:
+            _set_extremum_session_metrics(
+                session_group, self._request.aggregation_metric, min
+            )
+        elif self._request.aggregation_type == api_pb2.AGGREGATION_MAX:
+            _set_extremum_session_metrics(
+                session_group, self._request.aggregation_metric, max
+            )
+        else:
+            raise error.HParamsError(
+                "Unknown aggregation_type in request: %s"
+                % self._request.aggregation_type
+            )
+
+    def _filter(self, session_groups, filters):
+        return [
+            sg for sg in session_groups if self._passes_all_filters(sg, filters)
+        ]
+
+    def _passes_all_filters(self, session_group, filters):
+        return all(filter_fn(session_group) for filter_fn in filters)
+
+    def _sort(self, session_groups, extractors):
+        """Sorts 'session_groups' in place according to _request.col_params."""
+
+        # Sort by session_group name so we have a deterministic order.
+        session_groups.sort(key=operator.attrgetter("name"))
+        # Sort by lexicographical order of the _request.col_params whose order
+        # is not ORDER_UNSPECIFIED. The first such column is the primary sorting
+        # key, the second is the secondary sorting key, etc. To achieve that we
+        # need to iterate on these columns in reverse order (thus the primary key
+        # is the key used in the last sort).
+        for col_param, extractor in reversed(
+            list(zip(self._request.col_params, extractors))
+        ):
+            if col_param.order == api_pb2.ORDER_UNSPECIFIED:
+                continue
+            if col_param.order == api_pb2.ORDER_ASC:
+                session_groups.sort(
+                    key=_create_key_func(
+                        extractor,
+                        none_is_largest=not col_param.missing_values_first,
+                    )
+                )
+            elif col_param.order == api_pb2.ORDER_DESC:
+                session_groups.sort(
+                    key=_create_key_func(
+                        extractor,
+                        none_is_largest=col_param.missing_values_first,
+                    ),
+                    reverse=True,
+                )
+            else:
+                raise error.HParamsError(
+                    "Unknown col_param.order given: %s" % col_param
+                )
+
+    def _create_response(self, session_groups):
+        return api_pb2.ListSessionGroupsResponse(
+            session_groups=session_groups[
+                self._request.start_index : self._request.start_index
+                + self._request.slice_size
+            ],
+            total_size=len(session_groups),
+        )
+
+
+def _create_key_func(extractor, none_is_largest):
+    """Returns a key_func to be used in list.sort().
+
+    Returns a key_func to be used in list.sort() that sorts session groups
+    by the value extracted by extractor. 'None' extracted values will either
+    be considered largest or smallest as specified by the "none_is_largest"
+    boolean parameter.
+
+    Args:
+      extractor: An extractor function that extract the key from the session
+        group.
+      none_is_largest: bool. If true treats 'None's as largest; otherwise
+        smallest.
+    """
+    if none_is_largest:
+
+        def key_func_none_is_largest(session_group):
+            value = extractor(session_group)
+            return (value is None, value)
+
+        return key_func_none_is_largest
+
+    def key_func_none_is_smallest(session_group):
+        value = extractor(session_group)
+        return (value is not None, value)
+
+    return key_func_none_is_smallest
+
+
+# Extractors. An extractor is a function that extracts some property (a metric
+# or a hyperparameter) from a SessionGroup instance.
+def _create_extractors(col_params):
+    """Creates extractors to extract properties corresponding to 'col_params'.
+
+    Args:
+      col_params: List of ListSessionGroupsRequest.ColParam protobufs.
+    Returns:
+      A list of extractor functions. The ith element in the
+      returned list extracts the column corresponding to the ith element of
+      _request.col_params
+    """
+    result = []
+    for col_param in col_params:
+        result.append(_create_extractor(col_param))
+    return result
+
+
+def _create_extractor(col_param):
+    if col_param.HasField("metric"):
+        return _create_metric_extractor(col_param.metric)
+    elif col_param.HasField("hparam"):
+        return _create_hparam_extractor(col_param.hparam)
+    else:
+        raise error.HParamsError(
+            'Got ColParam with both "metric" and "hparam" fields unset: %s'
+            % col_param
+        )
+
+
+def _create_metric_extractor(metric_name):
+    """Returns function that extracts a metric from a session group or a
+    session.
+
+    Args:
+      metric_name: tensorboard.hparams.MetricName protobuffer. Identifies the
+        metric to extract from the session group.
+    Returns:
+      A function that takes a tensorboard.hparams.SessionGroup or
+      tensorborad.hparams.Session protobuffer and returns the value of the metric
+      identified by 'metric_name' or None if the value doesn't exist.
+    """
+
+    def extractor_fn(session_or_group):
+        metric_value = _find_metric_value(session_or_group, metric_name)
+        return metric_value.value if metric_value else None
+
+    return extractor_fn
+
+
+def _find_metric_value(session_or_group, metric_name):
+    """Returns the metric_value for a given metric in a session or session
+    group.
+
+    Args:
+      session_or_group: A Session protobuffer or SessionGroup protobuffer.
+      metric_name: A MetricName protobuffer. The metric to search for.
+    Returns:
+      A MetricValue protobuffer representing the value of the given metric or
+      None if no such metric was found in session_or_group.
+    """
+    # Note: We can speed this up by converting the metric_values field
+    # to a dictionary on initialization, to avoid a linear search here. We'll
+    # need to wrap the SessionGroup and Session protos in a python object for
+    # that.
+    for metric_value in session_or_group.metric_values:
+        if (
+            metric_value.name.tag == metric_name.tag
+            and metric_value.name.group == metric_name.group
+        ):
+            return metric_value
+
+
+def _create_hparam_extractor(hparam_name):
+    """Returns an extractor function that extracts an hparam from a session
+    group.
+
+    Args:
+      hparam_name: str. Identies the hparam to extract from the session group.
+    Returns:
+      A function that takes a tensorboard.hparams.SessionGroup protobuffer and
+      returns the value, as a native Python object, of the hparam identified by
+      'hparam_name'.
+    """
+
+    def extractor_fn(session_group):
+        if hparam_name in session_group.hparams:
+            return _value_to_python(session_group.hparams[hparam_name])
+        return None
+
+    return extractor_fn
+
+
+# Filters. A filter is a boolean function that takes a session group and returns
+# True if it should be included in the result. Currently, Filters are functions
+# of a single column value extracted from the session group with a given
+# extractor specified in the construction of the filter.
+def _create_filters(col_params, extractors, *, include_hparam_filters=True):
+    """Creates filters for the given col_params.
+
+    Args:
+      col_params: List of ListSessionGroupsRequest.ColParam protobufs.
+      extractors: list of extractor functions of the same length as col_params.
+        Each element should extract the column described by the corresponding
+        element of col_params.
+      include_hparam_filters: bool that indicates whether hparam filters should
+        be generated. Defaults to True.
+    Returns:
+      A list of filter functions. Each corresponding to a single
+      col_params.filter oneof field of _request
+    """
+    result = []
+    for col_param, extractor in zip(col_params, extractors):
+        if not include_hparam_filters and col_param.hparam:
+            continue
+
+        a_filter = _create_filter(col_param, extractor)
+        if a_filter:
+            result.append(a_filter)
+    return result
+
+
+def _create_filter(col_param, extractor):
+    """Creates a filter for the given col_param and extractor.
+
+    Args:
+      col_param: A tensorboard.hparams.ColParams object identifying the column
+        and describing the filter to apply.
+      extractor: A function that extract the column value identified by
+        'col_param' from a tensorboard.hparams.SessionGroup protobuffer.
+    Returns:
+      A boolean function taking a tensorboard.hparams.SessionGroup protobuffer
+      returning True if the session group passes the filter described by
+      'col_param'. If col_param does not specify a filter (i.e. any session
+      group passes) returns None.
+    """
+    include_missing_values = not col_param.exclude_missing_values
+    if col_param.HasField("filter_regexp"):
+        value_filter_fn = _create_regexp_filter(col_param.filter_regexp)
+    elif col_param.HasField("filter_interval"):
+        value_filter_fn = _create_interval_filter(col_param.filter_interval)
+    elif col_param.HasField("filter_discrete"):
+        value_filter_fn = _create_discrete_set_filter(col_param.filter_discrete)
+    elif include_missing_values:
+        # No 'filter' field and include_missing_values is True.
+        # Thus, the resulting filter always returns True, so to optimize for this
+        # common case we do not include it in the list of filters to check.
+        return None
+    else:
+        value_filter_fn = lambda _: True
+
+    def filter_fn(session_group):
+        value = extractor(session_group)
+        if value is None:
+            return include_missing_values
+        return value_filter_fn(value)
+
+    return filter_fn
+
+
+def _create_regexp_filter(regex):
+    """Returns a boolean function that filters strings based on a regular exp.
+
+    Args:
+      regex: A string describing the regexp to use.
+    Returns:
+      A function taking a string and returns True if any of its substrings
+      matches regex.
+    """
+    # Warning: Note that python's regex library allows inputs that take
+    # exponential time. Time-limiting it is difficult. When we move to
+    # a true multi-tenant tensorboard server, the regexp implementation here
+    # would need to be replaced by something more secure.
+    compiled_regex = re.compile(regex)
+
+    def filter_fn(value):
+        if not isinstance(value, str):
+            raise error.HParamsError(
+                "Cannot use a regexp filter for a value of type %s. Value: %s"
+                % (type(value), value)
+            )
+        return re.search(compiled_regex, value) is not None
+
+    return filter_fn
+
+
+def _create_interval_filter(interval):
+    """Returns a function that checkes whether a number belongs to an interval.
+
+    Args:
+      interval: A tensorboard.hparams.Interval protobuf describing the interval.
+    Returns:
+      A function taking a number (float or int) that returns True if the number
+      belongs to (the closed) 'interval'.
+    """
+
+    def filter_fn(value):
+        if not isinstance(value, (int, float)):
+            raise error.HParamsError(
+                "Cannot use an interval filter for a value of type: %s, Value: %s"
+                % (type(value), value)
+            )
+        return interval.min_value <= value and value <= interval.max_value
+
+    return filter_fn
+
+
+def _create_discrete_set_filter(discrete_set):
+    """Returns a function that checks whether a value belongs to a set.
+
+    Args:
+      discrete_set: A list of objects representing the set.
+    Returns:
+      A function taking an object and returns True if its in the set. Membership
+      is tested using the Python 'in' operator (thus, equality of distinct
+      objects is computed using the '==' operator).
+    """
+
+    def filter_fn(value):
+        return value in discrete_set
+
+    return filter_fn
+
+
+def _value_to_python(value):
+    """Converts a google.protobuf.Value to a native Python object."""
+
+    assert isinstance(value, struct_pb2.Value)
+    field = value.WhichOneof("kind")
+    if field == "number_value":
+        return value.number_value
+    elif field == "string_value":
+        return value.string_value
+    elif field == "bool_value":
+        return value.bool_value
+    else:
+        raise ValueError("Unknown struct_pb2.Value oneof field set: %s" % field)
+
+
+@dataclasses.dataclass(frozen=True)
+class _MetricIdentifier:
+    """An identifier for a metric.
+
+    As protobuffers are mutable we can't use MetricName directly as a dict's key.
+    Instead, we represent MetricName protocol buffer as an immutable dataclass.
+
+    Attributes:
+      group: Metric group corresponding to the dataset on which the model was
+        evaluated.
+      tag: String tag associated with the metric.
+    """
+
+    group: str
+    tag: str
+
+
+class _MetricStats:
+    """A simple class to hold metric stats used in calculating metric averages.
+
+    Used in _set_avg_session_metrics(). See the comments in that function
+    for more details.
+
+    Attributes:
+      total: int. The sum of the metric measurements seen so far.
+      count: int. The number of largest-step measuremens seen so far.
+      total_step: int. The sum of the steps at which the measurements were taken
+      total_wall_time_secs: float. The sum of the wall_time_secs at
+        which the measurements were taken.
+    """
+
+    # We use slots here to catch typos in attributes earlier. Note that this makes
+    # this class incompatible with 'pickle'.
+    __slots__ = [
+        "total",
+        "count",
+        "total_step",
+        "total_wall_time_secs",
+    ]
+
+    def __init__(self):
+        self.total = 0
+        self.count = 0
+        self.total_step = 0
+        self.total_wall_time_secs = 0.0
+
+
+def _set_avg_session_metrics(session_group):
+    """Sets the metrics for the group to be the average of its sessions.
+
+    The resulting session group metrics consist of the union of metrics across
+    the group's sessions. The value of each session group metric is the average
+    of that metric values across the sessions in the group. The 'step' and
+    'wall_time_secs' fields of the resulting MetricValue field in the session
+    group are populated with the corresponding averages (truncated for 'step')
+    as well.
+
+    Args:
+      session_group: A SessionGroup protobuffer.
+    """
+    assert session_group.sessions, "SessionGroup cannot be empty."
+    # Algorithm: Iterate over all (session, metric) pairs and maintain a
+    # dict from _MetricIdentifier to _MetricStats objects.
+    # Then use the final dict state to compute the average for each metric.
+    metric_stats = collections.defaultdict(_MetricStats)
+    for session in session_group.sessions:
+        for metric_value in session.metric_values:
+            metric_name = _MetricIdentifier(
+                group=metric_value.name.group, tag=metric_value.name.tag
+            )
+            stats = metric_stats[metric_name]
+            stats.total += metric_value.value
+            stats.count += 1
+            stats.total_step += metric_value.training_step
+            stats.total_wall_time_secs += metric_value.wall_time_secs
+
+    del session_group.metric_values[:]
+    for metric_name, stats in metric_stats.items():
+        session_group.metric_values.add(
+            name=api_pb2.MetricName(
+                group=metric_name.group, tag=metric_name.tag
+            ),
+            value=float(stats.total) / float(stats.count),
+            training_step=stats.total_step // stats.count,
+            wall_time_secs=stats.total_wall_time_secs / stats.count,
+        )
+
+
+@dataclasses.dataclass(frozen=True)
+class _Measurement:
+    """Holds a session's metric value.
+
+    Attributes:
+      metric_value: Metric value of the session.
+      session_index: Index of the session in its group.
+    """
+
+    metric_value: Optional[api_pb2.MetricValue]
+    session_index: int
+
+
+def _set_median_session_metrics(session_group, aggregation_metric):
+    """Sets the metrics for session_group to those of its "median session".
+
+    The median session is the session in session_group with the median value
+    of the metric given by 'aggregation_metric'. The median is taken over the
+    subset of sessions in the group whose 'aggregation_metric' was measured
+    at the largest training step among the sessions in the group.
+
+    Args:
+      session_group: A SessionGroup protobuffer.
+      aggregation_metric: A MetricName protobuffer.
+    """
+    measurements = sorted(
+        _measurements(session_group, aggregation_metric),
+        key=operator.attrgetter("metric_value.value"),
+    )
+    median_session = measurements[(len(measurements) - 1) // 2].session_index
+    del session_group.metric_values[:]
+    session_group.metric_values.MergeFrom(
+        session_group.sessions[median_session].metric_values
+    )
+
+
+def _set_extremum_session_metrics(
+    session_group, aggregation_metric, extremum_fn
+):
+    """Sets the metrics for session_group to those of its "extremum session".
+
+    The extremum session is the session in session_group with the extremum value
+    of the metric given by 'aggregation_metric'. The extremum is taken over the
+    subset of sessions in the group whose 'aggregation_metric' was measured
+    at the largest training step among the sessions in the group.
+
+    Args:
+      session_group: A SessionGroup protobuffer.
+      aggregation_metric: A MetricName protobuffer.
+      extremum_fn: callable. Must be either 'min' or 'max'. Determines the type of
+        extremum to compute.
+    """
+    measurements = _measurements(session_group, aggregation_metric)
+    ext_session = extremum_fn(
+        measurements, key=operator.attrgetter("metric_value.value")
+    ).session_index
+    del session_group.metric_values[:]
+    session_group.metric_values.MergeFrom(
+        session_group.sessions[ext_session].metric_values
+    )
+
+
+def _measurements(session_group, metric_name):
+    """A generator for the values of the metric across the sessions in the
+    group.
+
+    Args:
+      session_group: A SessionGroup protobuffer.
+      metric_name: A MetricName protobuffer.
+    Yields:
+      The next metric value wrapped in a _Measurement instance.
+    """
+    for session_index, session in enumerate(session_group.sessions):
+        metric_value = _find_metric_value(session, metric_name)
+        if not metric_value:
+            continue
+        yield _Measurement(metric_value, session_index)
+
+
+def _build_data_provider_filters(col_params):
+    """Builds HyperparameterFilters from ColParams."""
+    filters = []
+    for col_param in col_params:
+        if not col_param.hparam:
+            # We do not pass metric filters to the DataProvider as it does not
+            # have the metric data for filtering.
+            continue
+
+        fltr = _build_data_provider_filter(col_param)
+        if fltr is None:
+            continue
+        filters.append(fltr)
+    return filters
+
+
+def _build_data_provider_filter(col_param):
+    """Builds HyperparameterFilter from ColParam.
+
+    Args:
+      col_param: ColParam that possibly contains filter information.
+
+    Returns:
+      None if col_param does not specify filter information.
+    """
+    if col_param.HasField("filter_regexp"):
+        filter_type = provider.HyperparameterFilterType.REGEX
+        fltr = col_param.filter_regexp
+    elif col_param.HasField("filter_interval"):
+        filter_type = provider.HyperparameterFilterType.INTERVAL
+        fltr = (
+            col_param.filter_interval.min_value,
+            col_param.filter_interval.max_value,
+        )
+    elif col_param.HasField("filter_discrete"):
+        filter_type = provider.HyperparameterFilterType.DISCRETE
+        fltr = [_value_to_python(b) for b in col_param.filter_discrete.values]
+    else:
+        return None
+
+    return provider.HyperparameterFilter(
+        hyperparameter_name=col_param.hparam,
+        filter_type=filter_type,
+        filter=fltr,
+    )
+
+
+def _build_data_provider_sort(col_params):
+    """Builds HyperparameterSorts from ColParams."""
+    sort = []
+    for col_param in col_params:
+        sort_item = _build_data_provider_sort_item(col_param)
+        if sort_item is None:
+            continue
+        sort.append(sort_item)
+    return sort
+
+
+def _build_data_provider_sort_item(col_param):
+    """Builds HyperparameterSort from ColParam.
+
+    Args:
+      col_param: ColParam that possibly contains sort information.
+
+    Returns:
+      None if col_param does not specify sort information.
+    """
+    if col_param.order == api_pb2.ORDER_UNSPECIFIED:
+        return None
+
+    sort_direction = (
+        provider.HyperparameterSortDirection.ASCENDING
+        if col_param.order == api_pb2.ORDER_ASC
+        else provider.HyperparameterSortDirection.DESCENDING
+    )
+    return provider.HyperparameterSort(
+        hyperparameter_name=col_param.hparam,
+        sort_direction=sort_direction,
+    )
+
+
+def _specifies_include(col_params):
+    """Determines whether any `ColParam` contains the `include_in_result` field.
+
+    In the case where none of the col_params contains the field, we should assume
+    that all fields should be included in the response.
+    """
+    return any(
+        col_param.HasField("include_in_result") for col_param in col_params
+    )
+
+
+def _get_hparams_to_include(col_params):
+    """Generates the list of hparams to include in the response.
+
+    The determination is based on the `include_in_result` field in ColParam. If
+    a ColParam either has `include_in_result: True` or does not specify the
+    field at all, then it should be included in the result.
+
+    Args:
+      col_params: A collection of `ColParams` protos.
+
+    Returns:
+      A list of names of hyperparameters to include in the response.
+    """
+    hparams_to_include = []
+    for col_param in col_params:
+        if (
+            col_param.HasField("include_in_result")
+            and not col_param.include_in_result
+        ):
+            # Explicitly set to exclude this hparam.
+            continue
+        if col_param.hparam:
+            hparams_to_include.append(col_param.hparam)
+    return hparams_to_include
+
+
+def _reduce_to_hparams_to_include(session_groups, col_params):
+    """Removes hparams from session_groups that should not be included.
+
+    Args:
+      session_groups: A collection of `SessionGroup` protos, which will be
+        modified in place.
+      col_params: A collection of `ColParams` protos.
+    """
+    hparams_to_include = _get_hparams_to_include(col_params)
+
+    for session_group in session_groups:
+        new_hparams = {
+            hparam: value
+            for (hparam, value) in session_group.hparams.items()
+            if hparam in hparams_to_include
+        }
+
+        session_group.ClearField("hparams")
+        for hparam, value in new_hparams.items():
+            session_group.hparams[hparam].CopyFrom(value)