xprof-nightly 2.22.3a20251208-cp311-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xprof/__init__.py +22 -0
- xprof/convert/_pywrap_profiler_plugin.so +0 -0
- xprof/convert/csv_writer.py +87 -0
- xprof/convert/raw_to_tool_data.py +232 -0
- xprof/convert/trace_events_json.py +105 -0
- xprof/integration_tests/tf_mnist.py +100 -0
- xprof/integration_tests/tf_profiler_session.py +40 -0
- xprof/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py +183 -0
- xprof/profile_plugin.py +1521 -0
- xprof/profile_plugin_loader.py +82 -0
- xprof/protobuf/dcn_collective_info_pb2.py +44 -0
- xprof/protobuf/dcn_slack_analysis_pb2.py +42 -0
- xprof/protobuf/diagnostics_pb2.py +36 -0
- xprof/protobuf/event_time_fraction_analyzer_pb2.py +42 -0
- xprof/protobuf/hardware_types_pb2.py +40 -0
- xprof/protobuf/hlo_stats_pb2.py +39 -0
- xprof/protobuf/inference_stats_pb2.py +86 -0
- xprof/protobuf/input_pipeline_pb2.py +52 -0
- xprof/protobuf/kernel_stats_pb2.py +38 -0
- xprof/protobuf/memory_profile_pb2.py +54 -0
- xprof/protobuf/memory_viewer_preprocess_pb2.py +49 -0
- xprof/protobuf/op_metrics_pb2.py +65 -0
- xprof/protobuf/op_profile_pb2.py +49 -0
- xprof/protobuf/op_stats_pb2.py +71 -0
- xprof/protobuf/overview_page_pb2.py +64 -0
- xprof/protobuf/pod_stats_pb2.py +45 -0
- xprof/protobuf/pod_viewer_pb2.py +61 -0
- xprof/protobuf/power_metrics_pb2.py +38 -0
- xprof/protobuf/roofline_model_pb2.py +42 -0
- xprof/protobuf/smart_suggestion_pb2.py +38 -0
- xprof/protobuf/source_info_pb2.py +36 -0
- xprof/protobuf/source_stats_pb2.py +48 -0
- xprof/protobuf/steps_db_pb2.py +76 -0
- xprof/protobuf/task_pb2.py +37 -0
- xprof/protobuf/tf_data_stats_pb2.py +72 -0
- xprof/protobuf/tf_function_pb2.py +52 -0
- xprof/protobuf/tf_stats_pb2.py +40 -0
- xprof/protobuf/tfstreamz_pb2.py +40 -0
- xprof/protobuf/topology_pb2.py +50 -0
- xprof/protobuf/tpu_input_pipeline_pb2.py +43 -0
- xprof/protobuf/trace_events_old_pb2.py +54 -0
- xprof/protobuf/trace_events_pb2.py +64 -0
- xprof/protobuf/trace_events_raw_pb2.py +45 -0
- xprof/protobuf/trace_filter_config_pb2.py +40 -0
- xprof/server.py +319 -0
- xprof/standalone/base_plugin.py +52 -0
- xprof/standalone/context.py +22 -0
- xprof/standalone/data_provider.py +32 -0
- xprof/standalone/plugin_asset_util.py +131 -0
- xprof/standalone/plugin_event_multiplexer.py +185 -0
- xprof/standalone/tensorboard_shim.py +31 -0
- xprof/static/bundle.js +130500 -0
- xprof/static/index.html +64 -0
- xprof/static/index.js +3 -0
- xprof/static/materialicons.woff2 +0 -0
- xprof/static/styles.css +1 -0
- xprof/static/trace_viewer_index.html +3929 -0
- xprof/static/trace_viewer_index.js +15906 -0
- xprof/static/zone.js +3558 -0
- xprof/version.py +17 -0
- xprof_nightly-2.22.3a20251208.dist-info/METADATA +301 -0
- xprof_nightly-2.22.3a20251208.dist-info/RECORD +65 -0
- xprof_nightly-2.22.3a20251208.dist-info/WHEEL +5 -0
- xprof_nightly-2.22.3a20251208.dist-info/entry_points.txt +5 -0
- xprof_nightly-2.22.3a20251208.dist-info/top_level.txt +1 -0
xprof/profile_plugin.py
ADDED
@@ -0,0 +1,1521 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard plugin for performance profiling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections.abc import Callable, Iterator, Mapping
import gzip
import json
import logging
import os
import re
import sys
import threading
from typing import Any, Dict, List, Optional, Sequence, TypedDict

from etils import epath
import etils.epath.backend
from fsspec import core
import six
from werkzeug import wrappers

from xprof import version
from xprof.convert import raw_to_tool_data as convert
from xprof.standalone.tensorboard_shim import base_plugin
from xprof.standalone.tensorboard_shim import plugin_asset_util
from xprof.convert import _pywrap_profiler_plugin

logger = logging.getLogger('tensorboard.plugins.profile')
logger.setLevel(logging.INFO)
if not logger.handlers:
  handler = logging.StreamHandler(sys.stderr)
  formatter = logging.Formatter(
      '%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s'
  )
  handler.setFormatter(formatter)
  logger.addHandler(handler)
  logger.propagate = False

try:
  import tensorflow.compat.v2 as tf  # pylint: disable=g-import-not-at-top # pytype: disable=import-error

  tf.enable_v2_behavior()
except ImportError:
  logger.info(
      'Disabling some remote capture features as tensorflow is not available'
  )
  tf = None


# The prefix of routes provided by this plugin.
TB_NAME = 'plugins'
PLUGIN_NAME = 'profile'

BASE_ROUTE = '/'
INDEX_JS_ROUTE = '/index.js'
INDEX_HTML_ROUTE = '/index.html'
BUNDLE_JS_ROUTE = '/bundle.js'
STYLES_CSS_ROUTE = '/styles.css'
MATERIALICONS_WOFF2_ROUTE = '/materialicons.woff2'
TRACE_VIEWER_INDEX_HTML_ROUTE = '/trace_viewer_index.html'
TRACE_VIEWER_INDEX_JS_ROUTE = '/trace_viewer_index.js'
ZONE_JS_ROUTE = '/zone.js'
DATA_ROUTE = '/data'
RUNS_ROUTE = '/runs'
RUN_TOOLS_ROUTE = '/run_tools'
HOSTS_ROUTE = '/hosts'
HLO_MODULE_LIST_ROUTE = '/module_list'
CAPTURE_ROUTE = '/capture_profile'
LOCAL_ROUTE = '/local'
CONFIG_ROUTE = '/config'
CACHE_VERSION_FILE = 'cache_version.txt'

# Suffixes of "^, #, @" symbols represent different input data formats for the
# same tool.
# 1) '^': data generated from XPlane.
# 2) '#': data is in gzip format.
# 3) '@': data generated from proto, or tracetable for streaming trace viewer.
# 4) no suffix: data is in json format, ready to feed to frontend.
TOOLS = {
    'xplane': 'xplane.pb',
    'hlo_proto': 'hlo_proto.pb',
}

ALL_HOSTS = 'ALL_HOSTS'

HostMetadata = TypedDict('HostMetadata', {'hostname': str})

_EXTENSION_TO_TOOL = {extension: tool for tool, extension in TOOLS.items()}

_FILENAME_RE = re.compile(r'(?:(.*)\.)?(' +
                          '|'.join(TOOLS.values()).replace('.', r'\.') + r')')


# Tools that can be generated from xplane end with ^.
XPLANE_TOOLS = [
    'trace_viewer',  # non-streaming before TF 2.13
    'trace_viewer@',  # streaming since TF 2.14
    'overview_page',
    'input_pipeline_analyzer',
    'framework_op_stats',
    'kernel_stats',
    'memory_profile',
    'pod_viewer',
    'op_profile',
    'hlo_stats',
    'roofline_model',
    'inference_profile',
    'memory_viewer',
    'graph_viewer',
    'megascale_stats',
]

# XPlane generated tools that support all host mode.
XPLANE_TOOLS_ALL_HOSTS_SUPPORTED = frozenset([
    'input_pipeline_analyzer',
    'framework_op_stats',
    'kernel_stats',
    'overview_page',
    'pod_viewer',
    'megascale_stats',
])

# XPlane generated tools that only support all host mode.
XPLANE_TOOLS_ALL_HOSTS_ONLY = frozenset(
    ['overview_page', 'pod_viewer'])

# Rate limiter constants, per the GCS quota defined at
# https://cloud.google.com/storage/quotas#rate-quotas,
# currently set to 1000 requests per minute.
# TODO(kcai): The assumption on the average number of subdirs is not
# always true. If this is not sufficient, we can consider a token-based
# approach that counts the number of subdirs after calling iterdir.
MAX_GCS_REQUESTS = 1000
LIMIT_WINDOW_SECONDS = 60
AVERAGE_SUBDIR_NUMBER = 10


def use_xplane(tool: str) -> bool:
  return tool in XPLANE_TOOLS


# HLO generated tools.
HLO_TOOLS = frozenset(['graph_viewer', 'memory_viewer'])


def use_hlo(tool: str) -> bool:
  return tool in HLO_TOOLS


def make_filename(host: str, tool: str) -> str:
  """Returns the name of the file containing data for the given host and tool.

  Args:
    host: Name of the host that produced the profile data, e.g., 'localhost'.
    tool: Name of the tool, e.g., 'trace_viewer'.

  Returns:
    The host name concatenated with the tool-specific extension, e.g.,
    'localhost.xplane.pb'.
  """
  filename = str(host) + '.' if host else ''
  if use_hlo(tool):
    tool = 'hlo_proto'
  elif use_xplane(tool):
    tool = 'xplane'
  return filename + TOOLS[tool]
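
# For example (illustrative names, derived from the TOOLS mapping above):
#   make_filename('host1', 'trace_viewer')    -> 'host1.xplane.pb'
#   make_filename('module1', 'memory_viewer') -> 'module1.hlo_proto.pb'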


def _parse_filename(filename: str) -> tuple[Optional[str], Optional[str]]:
  """Returns the host and tool encoded in a filename in the run directory.

  Args:
    filename: Name of a file in the run directory. The name might encode a host
      and tool, e.g., 'host.tracetable', 'host.domain.op_profile.json', or just
      a tool, e.g., 'trace', 'tensorflow_stats.pb'.

  Returns:
    A tuple (host, tool) containing the names of the host and tool, e.g.,
    ('localhost', 'trace_viewer'). Either of the tuple's components can be None.
  """
  m = _FILENAME_RE.fullmatch(filename)
  if m is None:
    return filename, None
  return m.group(1), _EXTENSION_TO_TOOL[m.group(2)]
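
# For example (illustrative names):
#   _parse_filename('host1.xplane.pb')      -> ('host1', 'xplane')
#   _parse_filename('module2.hlo_proto.pb') -> ('module2', 'hlo_proto')
#   _parse_filename('notes.txt')            -> ('notes.txt', None)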


def _get_hosts(filenames: list[str]) -> set[str]:
  """Parses a list of filenames and returns the set of hosts.

  Args:
    filenames: A list of filenames (just basenames, no directory).

  Returns:
    A set of host names encoded in the filenames.
  """
  hosts = set()
  for name in filenames:
    host, _ = _parse_filename(name)
    if host:
      hosts.add(host)
  return hosts


def _get_tools(filenames: list[str], profile_run_dir: str) -> set[str]:
  """Parses a list of filenames and returns the set of tools.

  If xplane is present in the repository, add tools that can be generated by
  xplane if we don't have a file for the tool.

  Args:
    filenames: A list of filenames.
    profile_run_dir: The run directory of the profile.

  Returns:
    A set of tool names encoded in the filenames.
  """
  tools = set()
  found = set()
  xplane_filenames = []
  for name in filenames:
    _, tool = _parse_filename(name)
    if tool == 'xplane':
      xplane_filenames.append(os.path.join(profile_run_dir, name))
      continue
    elif tool == 'hlo_proto':
      continue
    elif tool:
      tools.add(tool)
      if tool[-1] == '@':
        found.add(tool[:-1])
      else:
        found.add(tool)
  # profile_run_dir might be empty, like in the Cloud AI use case.
  if not profile_run_dir:
    if xplane_filenames:
      for item in XPLANE_TOOLS:
        if item[:-1] not in found:
          tools.add(item)
  else:
    try:
      if xplane_filenames:
        return set(convert.xspace_to_tool_names(xplane_filenames))
    except AttributeError:
      logger.warning('XPlane converters are available after TensorFlow 2.4')
  return tools


def respond(
    body: Any,
    content_type: str,
    code: int = 200,
    content_encoding: Optional[str] = None,
) -> wrappers.Response:
  """Create a Werkzeug response, handling JSON serialization and CSP.

  Args:
    body: For JSON responses, a JSON-serializable object; otherwise, a raw
      `bytes` string or Unicode `str` (which will be encoded as UTF-8).
    content_type: Response content-type (`str`); use `application/json` to
      automatically serialize structures.
    code: HTTP status code (`int`).
    content_encoding: Response Content-Encoding header (`str`), e.g. 'gzip'.
      If not set, the body is gzip-compressed and the Content-Encoding header
      is set to 'gzip'.

  Returns:
    A `werkzeug.wrappers.Response` object.
  """
  if content_type == 'application/json' and isinstance(
      body, (dict, list, set, tuple)):
    body = json.dumps(body, sort_keys=True)
  if not isinstance(body, bytes):
    body = body.encode('utf-8')
  csp_parts = {
      'default-src': ["'self'"],
      'script-src': [
          "'self'",
          "'unsafe-eval'",
          "'unsafe-inline'",
          'https://www.gstatic.com',
      ],
      'object-src': ["'none'"],
      'style-src': [
          "'self'",
          "'unsafe-inline'",
          'https://fonts.googleapis.com',
          'https://www.gstatic.com',
      ],
      'font-src': [
          "'self'",
          'https://fonts.googleapis.com',
          'https://fonts.gstatic.com',
          'data:',
      ],
      'connect-src': [
          "'self'",
          'data:',
          'www.gstatic.com',
      ],
      'img-src': [
          "'self'",
          'blob:',
          'data:',
      ],
      'script-src-elem': [
          "'self'",
          "'unsafe-inline'",
          # Remember to restrict on integrity when importing from jsdelivr.
          # Whitelist this domain to support hlo_graph_dumper html format.
          'https://cdn.jsdelivr.net/npm/',
          'https://www.gstatic.com',
      ],
  }
  csp = ';'.join((' '.join([k] + v) for (k, v) in csp_parts.items()))
  headers = [
      ('Content-Security-Policy', csp),
      ('X-Content-Type-Options', 'nosniff'),
  ]
  if content_encoding:
    headers.append(('Content-Encoding', content_encoding))
  else:
    headers.append(('Content-Encoding', 'gzip'))
    body = gzip.compress(body)
  return wrappers.Response(
      body, content_type=content_type, status=code, headers=headers
  )
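
# For example, respond({'ok': True}, 'application/json') returns a
# gzip-compressed JSON response carrying the CSP headers built above.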


def _plugin_assets(
    session_dir: str, runs: list[str], plugin_name: str
) -> dict[str, list[str]]:
  result = {}
  for run in runs:
    run_path = _tb_run_directory(session_dir, run)
    assets = plugin_asset_util.ListAssets(run_path, plugin_name)
    result[run] = assets
  return result


def _tb_run_directory(session_dir: str, run: str) -> str:
  """Returns the TensorBoard run directory for a TensorBoard run name.

  This helper returns the TensorBoard-level run directory (the one that would
  contain tfevents files) for a given TensorBoard run name (aka the relative
  path from the session_dir root to this directory). For the root run '.'
  this is the bare session_dir path; for all other runs this is the
  session_dir joined with the run name.

  Args:
    session_dir: the TensorBoard log directory root path
    run: the TensorBoard run name, e.g. '.' or 'train'

  Returns:
    The TensorBoard run directory path, e.g. my/session_dir or
    my/session_dir/train.
  """
  return session_dir if run == '.' else os.path.join(session_dir, run)


def filenames_to_hosts(filenames: list[str], tool: str) -> list[str]:
  """Convert a list of filenames to a list of host names given a tool.

  Args:
    filenames: A list of filenames.
    tool: A string representing the profiling tool.

  Returns:
    A list of hostnames.
  """
  hosts = _get_hosts(filenames)
  if len(hosts) > 1:
    if tool in XPLANE_TOOLS_ALL_HOSTS_ONLY:
      hosts = [ALL_HOSTS]
    elif tool in XPLANE_TOOLS_ALL_HOSTS_SUPPORTED:
      hosts.add(ALL_HOSTS)
  return sorted(hosts)
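
# For example (illustrative filenames), with two hosts captured:
#   filenames_to_hosts(['h1.xplane.pb', 'h2.xplane.pb'], 'overview_page')
#   -> ['ALL_HOSTS']  # overview_page only supports all-hosts mode
#   filenames_to_hosts(['h1.xplane.pb', 'h2.xplane.pb'], 'kernel_stats')
#   -> ['ALL_HOSTS', 'h1', 'h2']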


def _get_bool_arg(
    args: Mapping[str, Any], arg_name: str, default: bool
) -> bool:
  """Gets a boolean argument from a request.

  Args:
    args: The werkzeug request arguments.
    arg_name: The name of the argument.
    default: The default value if the argument is not present.

  Returns:
    The boolean value of the argument.
  """
  arg_str = args.get(arg_name)
  if arg_str is None:
    return default
  return arg_str.lower() == 'true'
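
# For example, data_impl below reads
#   _get_bool_arg(request.args, 'use_saved_result', True)
# so '?use_saved_result=false' disables cached results for a request.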


class ToolsCache:
  """Caches the list of tools for a profile run based on file content hashes or mtimes.

  Attributes:
    CACHE_FILE_NAME: The name of the cache file.
    CACHE_VERSION: The version of the cache format.
  """

  CACHE_FILE_NAME = '.cached_tools.json'
  CACHE_VERSION = 1

  def __init__(self, profile_run_dir: epath.Path):
    """Initializes the ToolsCache.

    Args:
      profile_run_dir: The directory containing the profile run data.
    """
    self._profile_run_dir = profile_run_dir
    self._cache_file = self._profile_run_dir / self.CACHE_FILE_NAME
    logger.info('ToolsCache initialized for %s', self._cache_file)

  def _get_local_file_identifier(self, file_path_str: str) -> Optional[str]:
    """Gets a string identifier for a local file.

    The identifier is a combination of the file's last modification time
    (mtime) and size, in the format "{mtime}-{size}".

    Args:
      file_path_str: The absolute path to the local file.

    Returns:
      A string identifier, or None if the file is not found or an error occurs.
    """
    try:
      stat_result = os.stat(file_path_str)
      return f'{int(stat_result.st_mtime)}-{stat_result.st_size}'
    except FileNotFoundError:
      logger.warning('Local file not found: %s', file_path_str)
      return None
    except OSError as e:
      logger.error(
          'OSError getting stat for local file %s: %r', file_path_str, e
      )
      return None

  def _get_gcs_file_hash(self, file_path_str: str) -> Optional[str]:
    """Gets the MD5 hash for a GCS file.

    Args:
      file_path_str: The GCS path (e.g., "gs://bucket/object").

    Returns:
      The MD5 hash string, or None if the file is not found or an error occurs.
    """
    try:
      fs = core.get_fs_token_paths(file_path_str)[0]
      info = fs.info(file_path_str)
      md5_hash = info.get('md5Hash')

      if not isinstance(md5_hash, str):
        logger.warning(
            'Could not find a valid md5Hash string in info for %s: %s',
            file_path_str,
            info,
        )
        return None

      return md5_hash

    except FileNotFoundError:
      logger.warning('GCS path not found: %s', file_path_str)
      return None
    except IndexError:
      logger.error('Could not get filesystem for GCS path: %s', file_path_str)
      return None
    except Exception as e:  # pylint: disable=broad-exception-caught
      logger.exception(
          'Unexpected error getting hash for GCS path %s: %r', file_path_str, e
      )
      return None

  def get_file_identifier(self, file_path_str: str) -> Optional[str]:
    """Gets a string identifier for a file.

    For GCS files, this is the MD5 hash.
    For local files, this is a string combining mtime and size.

    Args:
      file_path_str: The full path to the file (local or GCS).

    Returns:
      A string identifier, or None if an error occurs.
    """
    if file_path_str.startswith('gs://'):
      return self._get_gcs_file_hash(file_path_str)
    else:
      return self._get_local_file_identifier(file_path_str)

  def _get_current_xplane_file_states(self) -> Optional[Dict[str, str]]:
    """Gets the current state of XPlane files in the profile run directory.

    Returns:
      A dictionary mapping filename to a string identifier (hash or mtime-size),
      or None if any file state cannot be determined.
    """
    try:
      file_identifiers = {}
      for xplane_file in self._profile_run_dir.glob(f"*.{TOOLS['xplane']}"):
        file_id = self.get_file_identifier(str(xplane_file))
        if file_id is None:
          logger.warning(
              'Could not get identifier for %s, cache will be invalidated.',
              xplane_file,
          )
          return None
        file_identifiers[xplane_file.name] = file_id
      return file_identifiers
    except OSError as e:
      logger.warning('Could not glob files in %s: %r', self._profile_run_dir, e)
      return None

  def load(self) -> Optional[List[str]]:
    """Loads the cached list of tools if the cache is valid.

    The cache is valid if the cache file exists, the version matches, and
    the file states (hashes/mtimes) of the XPlane files have not changed.

    Returns:
      A list of tool names if the cache is valid, otherwise None.
    """
    try:
      with self._cache_file.open('r') as f:
        cached_data = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
      logger.warning(
          'Error reading or decoding cache file %s: %r, invalidating.',
          self._cache_file,
          e,
      )
      self.invalidate()
      return None

    if cached_data.get('version') != self.CACHE_VERSION:
      logger.info(
          'ToolsCache invalid: version mismatch, expected %s, got %s.'
          ' Invalidating %s',
          self.CACHE_VERSION,
          cached_data.get('version'),
          self._cache_file,
      )
      self.invalidate()
      return None

    current_files = self._get_current_xplane_file_states()
    if current_files is None:
      logger.info(
          'ToolsCache invalid: could not determine current file states.'
          ' Invalidating %s',
          self._cache_file,
      )
      self.invalidate()
      return None

    if cached_data.get('files') != current_files:
      logger.info(
          'ToolsCache invalid: file states differ. Invalidating %s',
          self._cache_file,
      )
      self.invalidate()
      return None

    logger.info('ToolsCache hit: %s', self._cache_file)
    return cached_data.get('tools')

  def save(self, tools: Sequence[str]) -> None:
    """Saves the list of tools and the current file states to the cache file.

    Args:
      tools: The list of tool names to cache.
    """
    current_files_for_cache = self._get_current_xplane_file_states()
    if current_files_for_cache is None:
      logger.warning(
          'ToolsCache not saved: could not get file states %s', self._cache_file
      )
      return

    new_cache_data = {
        'version': self.CACHE_VERSION,
        'files': current_files_for_cache,
        'tools': tools,
    }
    try:
      with self._cache_file.open('w') as f:
        json.dump(new_cache_data, f, sort_keys=True, indent=2)
      logger.info('ToolsCache saved: %s', self._cache_file)
    except (OSError, TypeError) as e:
      logger.error('Error writing cache file %s: %r', self._cache_file, e)

  def invalidate(self) -> None:
    """Deletes the cache file, forcing regeneration on the next load."""
    try:
      self._cache_file.unlink()
      logger.info('ToolsCache invalidated: %s', self._cache_file)
    except FileNotFoundError:
      pass
    except OSError as e:
      logger.error('Error removing cache file %s: %r', self._cache_file, e)
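
# The cache file written by ToolsCache.save() looks like, for example
# (illustrative values):
#   {"files": {"host1.xplane.pb": "1733600000-123456"},
#    "tools": ["overview_page", "trace_viewer@"],
#    "version": 1}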


class _TfProfiler:
  """A helper class to encapsulate all TensorFlow-dependent profiler logic."""

  def __init__(self, tf_module):
    if not tf_module:
      raise ImportError('TensorFlow module is not available.')
    self.tf = tf_module

  def _get_worker_list(self, cluster_resolver) -> str:
    """Parses TPU workers list from the cluster resolver."""
    cluster_spec = cluster_resolver.cluster_spec()
    task_indices = cluster_spec.task_indices('worker')
    worker_list = [
        cluster_spec.task_address('worker', i).replace(':8470', ':8466')
        for i in task_indices
    ]
    return ','.join(worker_list)

  def resolve_tpu_name(
      self, tpu_name: str, worker_list: str
  ) -> tuple[str, str, str]:
    """Resolves a TPU name to its master IP, service address, and worker list.

    Args:
      tpu_name: The name of the TPU to resolve.
      worker_list: A comma-separated list of worker addresses.

    Returns:
      A tuple containing (service_addr, worker_list, master_ip).
    """
    try:
      resolver = self.tf.distribute.cluster_resolver.TPUClusterResolver(
          tpu_name
      )
      master_grpc_addr = resolver.get_master()
    except RuntimeError as err:
      # Propagate error to be handled by the caller.
      raise RuntimeError(
          f'Error initializing TPUClusterResolver: {err}'
      ) from err
    except (ValueError, TypeError) as e:
      # Handle cases where the TPU name is invalid.
      raise ValueError(f'No TPU found with the name: {tpu_name}') from e

    if not worker_list:
      worker_list = self._get_worker_list(resolver)

    # TPU cluster resolver always returns port 8470. Replace it with 8466,
    # on which the profiler service is running.
    master_ip = master_grpc_addr.replace('grpc://', '').replace(':8470', '')
    service_addr = f'{master_ip}:8466'
    return service_addr, worker_list, master_ip
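
# For example (illustrative address), a resolved master address of
# 'grpc://10.0.0.2:8470' yields master_ip '10.0.0.2' and
# service_addr '10.0.0.2:8466'.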


class ProfilePlugin(base_plugin.TBPlugin):
  """Profile Plugin for TensorBoard."""
  plugin_name = PLUGIN_NAME

  def __init__(self, context):
    """Constructs a profiler plugin for TensorBoard.

    This plugin adds handlers for performance-related frontends.
    Args:
      context: A base_plugin.TBContext instance.
    """
    self.logdir = context.logdir
    self.data_provider = context.data_provider
    self.master_tpu_unsecure_channel = context.flags.master_tpu_unsecure_channel
    self.hide_capture_profile_button = getattr(
        context, 'hide_capture_profile_button', False
    )

    # Whether the plugin is active. This is an expensive computation, so we
    # compute this asynchronously and cache positive results indefinitely.
    self._is_active = False
    # Lock to ensure at most one thread computes _is_active at a time.
    self._is_active_lock = threading.Lock()
    # Cache to map profile run name to corresponding tensorboard dir name.
    self._run_to_profile_run_dir = {}
    self._tf_profiler = _TfProfiler(tf) if tf else None

  def is_active(self) -> bool:
    """Whether this plugin is active and has any profile data to show.

    Returns:
      Whether any run has profile data.
    """
    if not self._is_active:
      self._is_active = any(self.generate_runs())
    return self._is_active

  def _does_tool_support_multi_hosts_processing(self, tool: str) -> bool:
    """Returns true if the tool supports multi-hosts processing."""
    return tool == 'trace_viewer@' or tool == 'trace_viewer'

  def get_plugin_apps(
      self,
  ) -> dict[str, Callable[[wrappers.Request], wrappers.Response]]:
    return {
        BASE_ROUTE: self.default_handler,
        INDEX_JS_ROUTE: self.static_file_route,
        INDEX_HTML_ROUTE: self.static_file_route,
        BUNDLE_JS_ROUTE: self.static_file_route,
        STYLES_CSS_ROUTE: self.static_file_route,
        MATERIALICONS_WOFF2_ROUTE: self.static_file_route,
        TRACE_VIEWER_INDEX_HTML_ROUTE: self.static_file_route,
        TRACE_VIEWER_INDEX_JS_ROUTE: self.static_file_route,
        ZONE_JS_ROUTE: self.static_file_route,
        RUNS_ROUTE: self.runs_route,
        RUN_TOOLS_ROUTE: self.run_tools_route,
        HOSTS_ROUTE: self.hosts_route,
        DATA_ROUTE: self.data_route,
        HLO_MODULE_LIST_ROUTE: self.hlo_module_list_route,
        CAPTURE_ROUTE: self.capture_route,
        LOCAL_ROUTE: self.default_handler,
        CONFIG_ROUTE: self.config_route,
    }

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def default_handler(self, _: wrappers.Request) -> wrappers.Response:
    contents = self._read_static_file_impl('index.html')
    return respond(contents, 'text/html')

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def config_route(self, request: wrappers.Request) -> wrappers.Response:
    # pytype: enable=wrong-arg-types
    """Returns UI configuration details."""
    logger.info('config_route: %s', self.logdir)
    config_data = {
        'hideCaptureProfileButton': self.hide_capture_profile_button,
    }
    return respond(config_data, 'application/json')

  def frontend_metadata(self):
    return base_plugin.FrontendMetadata(es_module_path='/index.js')

  def _read_static_file_impl(self, filename: str) -> bytes:
    """Reads contents from a filename.

    Args:
      filename (str): Name of the file.

    Returns:
      Contents of the file.
    Raises:
      IOError: File could not be read or found.
    """
    filepath = os.path.join(os.path.dirname(__file__), 'static', filename)

    try:
      with open(filepath, 'rb') as infile:
        contents = infile.read()
    except IOError as io_error:
      raise io_error
    return contents

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def static_file_route(self, request: wrappers.Request) -> wrappers.Response:
    # pytype: enable=wrong-arg-types
    filename = os.path.basename(request.path)
    extension = os.path.splitext(filename)[1]
    if extension == '.html':
      mimetype = 'text/html'
    elif extension == '.css':
      mimetype = 'text/css'
    elif extension == '.js':
      mimetype = 'application/javascript'
    else:
      mimetype = 'application/octet-stream'
    try:
      contents = self._read_static_file_impl(filename)
    except IOError:
      return respond('Failed to read the file.', 'text/plain', code=404)
    return respond(contents, mimetype)

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def runs_route(self, request: wrappers.Request) -> wrappers.Response:
    # pytype: enable=wrong-arg-types
    runs = self.runs_imp(request)
    return respond(runs, 'application/json')

  def _run_map_from_request(
      self, request: Optional[wrappers.Request] = None
  ) -> Optional[dict[str, str]]:
    """Returns a map of run names to session directories from the request.

    Args:
      request: Optional; werkzeug request used for grabbing session_path and
        run_path arguments.
    """
    session_path_arg = request.args.get('session_path') if request else None
    run_path_arg = (
        request.args.get('run_path')
        if request and not session_path_arg
        else None
    )
    run_map = None
    if session_path_arg:
      session_path = epath.Path(session_path_arg)
      run_name = session_path.name
      run_map = {}
      if session_path.is_dir() and any(session_path.glob('*.xplane.pb')):
        run_map[run_name] = str(session_path)
    elif run_path_arg:
      run_path = epath.Path(run_path_arg)
      run_map = {}
      for session in run_path.iterdir():
        if session.is_dir() and any(session.glob('*.xplane.pb')):
          run_map[session.name] = str(session)
    return run_map

  def _run_dir(
      self, run: str, request: Optional[wrappers.Request] = None
  ) -> Optional[str]:
    """Helper that maps a frontend run name to a profile "run" directory.

    The frontend run name consists of the TensorBoard run name (aka the
    relative path from the logdir root to the directory containing the data)
    path-joined to the Profile plugin's "run" concept (which is a subdirectory
    of the plugins/profile directory representing an individual run of the
    tool), with the special case that if the TensorBoard run is the logdir
    root (which is the run named '.'), then only the Profile plugin "run" name
    is used, for backwards compatibility.

    Args:
      run: the frontend run name, as described above, e.g. train/run1.
      request: Optional; werkzeug request used for grabbing session_path and
        run_path arguments.

    Returns:
      The resolved directory path, e.g. /logdir/train/plugins/profile/run1.

    Raises:
      ValueError: If the run is not found in the run map.
      RuntimeError: If the run directory is not found.
    """
    run_map = self._run_map_from_request(request)
    if run_map is not None:
      if run in run_map:
        return run_map[run]
      else:
        raise ValueError(f'Run {run} not found in run map: {run_map}')

    if run in self._run_to_profile_run_dir:
      return self._run_to_profile_run_dir[run]

    if not self.logdir:
      raise RuntimeError(
          'No matching run directory for run %s. Logdir is empty.' % run
      )
    tb_run_name, profile_run_name = os.path.split(run.rstrip(os.sep))
    if not tb_run_name:
      tb_run_name = '.'
    tb_run_directory = _tb_run_directory(self.logdir, tb_run_name)
    if not epath.Path(tb_run_directory).is_dir():
      raise RuntimeError('No matching run directory for run %s' % run)
    plugin_directory = plugin_asset_util.PluginDirectory(
        tb_run_directory, PLUGIN_NAME
    )
    return os.path.join(plugin_directory, profile_run_name)

  def runs_imp(self, request: Optional[wrappers.Request] = None) -> list[str]:
    """Returns a list of all runs for the profile plugin.

    Args:
      request: Optional; werkzeug request used for grabbing ctx and experiment
        id for other host implementations.
    """
    run_map = self._run_map_from_request(request)
    if run_map is not None:
      runs = run_map.keys()
    else:
      runs = self.generate_runs()
    return sorted(runs, reverse=True)

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def run_tools_route(self, request: wrappers.Request) -> wrappers.Response:
    # pytype: enable=wrong-arg-types
    run = request.args.get('run')
    run_tools = self.run_tools_imp(run, request)
    return respond(run_tools, 'application/json')

  def run_tools_imp(
      self, run, request: Optional[wrappers.Request] = None
  ) -> list[str]:
    """Returns a list of tools given a single run.

    Args:
      run: the frontend run name, an item of the list returned by runs_imp.
      request: Optional; werkzeug request used for grabbing ctx and experiment
        id for other host implementations.
    """
    run_dir = self._run_dir(run, request)
    return list(self.generate_tools_of_run(run, run_dir))

  def _run_host_impl(
      self, run: str, run_dir: str, tool: str
  ) -> List[HostMetadata]:
    if not run_dir:
      logger.warning('Cannot find asset directory for: %s', run)
      return []
    tool_pattern = '*.xplane.pb'
    filenames = []
    try:
      path = epath.Path(run_dir)
      filenames = path.glob(tool_pattern)
    except OSError as e:
      logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e)
    filenames = [os.fspath(os.path.basename(f)) for f in filenames]

    return [{'hostname': host} for host in filenames_to_hosts(filenames, tool)]

  def host_impl(
      self, run: str, tool: str, request: Optional[wrappers.Request] = None
  ) -> List[HostMetadata]:
    """Returns available hosts and their metadata for the run and tool in the log directory.

    In the plugin log directory, each directory contains profile data for a
    single run (identified by the directory name), and files in the run
    directory contain data for different tools and hosts. The file that
    contains the profile for a specific tool "x" will have extension TOOLS["x"].

    Example:
      log/
        run1/
          plugins/
            profile/
              host1.trace
              host2.trace
              module1.hlo_proto.pb
              module2.hlo_proto.pb
        run2/
          plugins/
            profile/
              host1.trace
              host2.trace

    Args:
      run: the frontend run name, e.g., 'run1' or 'run2' for the example above.
      tool: the requested tool, e.g., 'trace_viewer' for the example above.
      request: Optional; werkzeug request used for grabbing ctx and experiment
        id for other host implementations.

    Returns:
      A list of host names, e.g.:
        host_impl(run1, trace_viewer) --> [{"hostname": "host1"},
                                           {"hostname": "host2"}]
        host_impl(run1, memory_viewer) --> [{"hostname": "module1"},
                                            {"hostname": "module2"}]
    """
    run_dir = self._run_dir(run, request)
    return self._run_host_impl(run, run_dir, tool)

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def hosts_route(self, request: wrappers.Request) -> wrappers.Response:
    # pytype: enable=wrong-arg-types
    run = request.args.get('run')
    tool = request.args.get('tag')
    hosts = self.host_impl(run, tool, request)
    return respond(hosts, 'application/json')

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def hlo_module_list_route(
      self, request: wrappers.Request
  ) -> wrappers.Response:
    module_names_str = self.hlo_module_list_impl(request)
    return respond(module_names_str, 'text/plain')

  def _get_valid_hosts(
      self, run_dir: str, run: str, tool: str, hosts_param: str, host: str
  ) -> tuple[List[str], List[epath.Path]]:
    """Retrieves and validates the hosts and asset paths for a run and tool.

    Args:
      run_dir: The run directory.
      run: The frontend run name.
      tool: The requested tool.
      hosts_param: Comma-separated list of selected hosts.
      host: The single host parameter.

    Returns:
      A tuple containing (selected_hosts, asset_paths).

    Raises:
      FileNotFoundError: If a required xplane file for the specified host(s)
        is not found.
      IOError: If there is an error reading asset directories.
    """
    asset_paths = []
    selected_hosts = []
    all_xplane_files = {}  # Map host to path.

    # Find all available xplane files for the run and map them by host.
    file_pattern = make_filename('*', 'xplane')
    try:
      path = epath.Path(run_dir)
      for xplane_path in path.glob(file_pattern):
        host_name, _ = _parse_filename(xplane_path.name)
        if host_name:
          all_xplane_files[host_name] = xplane_path
    except OSError as e:
      logger.warning('Cannot read asset directory: %s, OpError %r', run_dir, e)
      raise IOError(
          'Cannot read asset directory: %s, OpError %r' % (run_dir, e)
      ) from e

    if hosts_param and self._does_tool_support_multi_hosts_processing(tool):
      selected_hosts = hosts_param.split(',')
      for selected_host in selected_hosts:
        if selected_host in all_xplane_files:
          asset_paths.append(all_xplane_files[selected_host])
        else:
          raise FileNotFoundError(
              'No xplane file found for host: %s in run: %s'
              % (selected_host, run)
          )
      logger.info('Inside trace_viewer@, asset_paths: %s', asset_paths)
    elif host == ALL_HOSTS:
      asset_paths = list(all_xplane_files.values())
      selected_hosts = list(all_xplane_files.keys())
    elif host and host in all_xplane_files:
      selected_hosts = [host]
      asset_paths = [all_xplane_files[host]]
    elif host:
      logger.warning('No xplane file found for host: %s in run: %s', host, run)
      if tool not in XPLANE_TOOLS_ALL_HOSTS_ONLY:
        raise FileNotFoundError(
            'No xplane file found for host: %s in run: %s' % (host, run)
        )
    elif not host and not hosts_param and len(all_xplane_files) == 1:
      selected_hosts = list(all_xplane_files.keys())
      asset_paths = list(all_xplane_files.values())

    if not asset_paths:
      logger.warning(
          'No matching asset paths found for run %s, tool %s, host(s) %s / %s',
          run,
          tool,
          hosts_param,
          host,
      )
      if not host and tool not in XPLANE_TOOLS_ALL_HOSTS_ONLY:
        raise FileNotFoundError(
            'Host must be specified for tool %s in run %s' % (tool, run)
        )

    return selected_hosts, asset_paths

  def data_impl(
      self, request: wrappers.Request
  ) -> tuple[Optional[str], str, Optional[str]]:
    """Retrieves and processes the tool data for a run and a host.

    Args:
      request: XMLHttpRequest

    Returns:
      A string that can be served to the frontend tool, or None if the tool,
      run or host is invalid.

    Raises:
      FileNotFoundError: If a required xplane file for the specified host(s)
        is not found.
      IOError: If there is an error reading asset directories.
      AttributeError: If there is an error during xplane to tool data
        conversion.
      ValueError: If xplane conversion fails due to invalid data.
    """
    run = request.args.get('run')
    tool = request.args.get('tag')
    hosts_param = request.args.get('hosts')
    host = request.args.get('host')
    module_name = request.args.get('module_name')
    tqx = request.args.get('tqx')
    use_saved_result = _get_bool_arg(request.args, 'use_saved_result', True)
    full_dma = _get_bool_arg(request.args, 'full_dma', False)
    run_dir = self._run_dir(run, request)

    # If the cache version file exists but records a version older than the
    # current plugin version, ignore saved results.
    try:
      if epath.Path(os.path.join(run_dir, CACHE_VERSION_FILE)).exists():
        with epath.Path(os.path.join(run_dir, CACHE_VERSION_FILE)).open(
            'r'
        ) as f:
          cache_version = f.read().strip()
        if cache_version < version.__version__:
          use_saved_result = False
      else:
        use_saved_result = False
    except OSError as e:
      logger.warning('Cannot read cache version file: %r', e)
      use_saved_result = False

    graph_viewer_options = self._get_graph_viewer_options(request)
    # Host param is used by HLO tools to identify the module.
    params = {
        'graph_viewer_options': graph_viewer_options,
        'tqx': tqx,
        'host': host,
        'module_name': module_name,
        'use_saved_result': use_saved_result,
    }
    if request.args.get('group_by'):
      params['group_by'] = request.args.get('group_by')
    content_type = 'application/json'

    if tool not in TOOLS and not use_xplane(tool):
      return None, content_type, None
    if tool == 'memory_viewer' and request.args.get(
        'view_memory_allocation_timeline'
    ):
      params['view_memory_allocation_timeline'] = True

    params['memory_space'] = request.args.get('memory_space', '0')

    if tool == 'trace_viewer@':
      options = {}
      options['resolution'] = request.args.get('resolution', 8000)
      options['full_dma'] = full_dma
      if request.args.get('start_time_ms') is not None:
        options['start_time_ms'] = request.args.get('start_time_ms')
      if request.args.get('end_time_ms') is not None:
        options['end_time_ms'] = request.args.get('end_time_ms')
      if request.args.get('event_name') is not None:
        options['event_name'] = request.args.get('event_name')
      if request.args.get('duration_ms') is not None:
        options['duration_ms'] = request.args.get('duration_ms')
      if request.args.get('unique_id') is not None:
        options['unique_id'] = request.args.get('unique_id')
      if request.args.get('search_prefix') is not None:
        options['search_prefix'] = request.args.get('search_prefix')
      params['trace_viewer_options'] = options

    content_encoding = None
    if use_xplane(tool):
      selected_hosts, asset_paths = self._get_valid_hosts(
          run_dir, run, tool, hosts_param, host
      )
      if not asset_paths:
        return None, content_type, None

      params['hosts'] = selected_hosts
      try:
        data, content_type = convert.xspace_to_tool_data(
            asset_paths, tool, params)
      except AttributeError as e:
        logger.warning('Error generating analysis results due to %r', e)
        raise AttributeError(
            'Error generating analysis results due to %r' % e
        ) from e
      except ValueError as e:
        logger.warning('XPlane convert to tool data failed as %r', e)
        raise e
      except FileNotFoundError as e:
        logger.warning('XPlane convert to tool data failed as %r', e)
        raise e

      # Write the cache version file if use_saved_result is False.
      if not use_saved_result:
        try:
          with epath.Path(os.path.join(run_dir, CACHE_VERSION_FILE)).open(
              'w'
          ) as f:
            f.write(version.__version__)
        except OSError as e:
          logger.warning('Cannot write cache version file: %r', e)

      return data, content_type, content_encoding

    logger.info('%s does not use xplane', tool)
    return None, content_type, None
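
  # For example (illustrative query), a frontend request such as
  #   /data?run=train/run1&tag=trace_viewer@&host=host1&use_saved_result=false
  # is routed to data_impl via data_route below.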

  def hlo_module_list_impl(
      self, request: wrappers.Request
  ) -> str:
    """Returns a comma-separated string of HLO module names for the given run."""
    run = request.args.get('run')
    run_dir = self._run_dir(run, request)
    module_list = []
    if not run_dir:
      logger.warning('Cannot find asset directory for: %s', run)
      return ''
    tool_pattern = '*.hlo_proto.pb'
    filenames = []
    try:
      path = epath.Path(run_dir)
      filenames = path.glob(tool_pattern)
    except OSError as e:
      logger.warning('Cannot read asset directory: %s, OpError %r', run_dir, e)
    filenames = [os.fspath(os.path.basename(f)) for f in filenames]
    for filename in filenames:
      module_name, _ = _parse_filename(filename)
      if module_name:
        module_list.append(module_name)
    module_names_str = ','.join(module_list)
    return module_names_str

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def data_route(self, request: wrappers.Request) -> wrappers.Response:
    # pytype: enable=wrong-arg-types
    # params
    #   request: XMLHttpRequest.
    try:
      data, content_type, content_encoding = self.data_impl(request)
      if data is None:
        return respond('No Data', 'text/plain', code=404)
      return respond(data, content_type, content_encoding=content_encoding)
    # Data fetch error handler.
    except (TimeoutError, AttributeError, ValueError, FileNotFoundError,
            IOError) as e:
      return respond(str(e), 'text/plain', code=500)

  # pytype: disable=wrong-arg-types
  @wrappers.Request.application
  def capture_route(self, request: wrappers.Request) -> wrappers.Response:
    # pytype: enable=wrong-arg-types
    return self.capture_route_impl(request)

  def capture_route_impl(self, request: wrappers.Request) -> wrappers.Response:
    """Runs the client trace for capturing profiling information."""
    service_addr = request.args.get('service_addr')
    duration = int(request.args.get('duration', '1000'))
    is_tpu_name = request.args.get('is_tpu_name') == 'true'
    worker_list = request.args.get('worker_list')
    num_tracing_attempts = int(request.args.get('num_retry', '0')) + 1
    options = {
        'host_tracer_level': int(request.args.get('host_tracer_level', '2')),
        'device_tracer_level': int(
            request.args.get('device_tracer_level', '1')
        ),
        'python_tracer_level': int(
            request.args.get('python_tracer_level', '0')
        ),
        'delay_ms': int(request.args.get('delay', '0')),
    }

    if is_tpu_name:
      if not self._tf_profiler:
        return respond(
            {
                'error': (
                    'TensorFlow is not installed, but is required to use TPU'
                    ' names.'
                )
            },
            'application/json',
            code=500,
        )
      try:
        # Delegate to the helper class for all TF-related logic.
        service_addr, worker_list, master_ip = (
            self._tf_profiler.resolve_tpu_name(service_addr, worker_list or '')
        )
        self.master_tpu_unsecure_channel = master_ip
      except (RuntimeError, ValueError) as err:
        return respond({'error': str(err)}, 'application/json', code=500)

    if not self.logdir:
      return respond(
          {'error': 'logdir is not set, abort capturing.'},
          'application/json',
          code=500,
      )
    try:
      # The core trace call remains, now with cleanly resolved parameters.
      _pywrap_profiler_plugin.trace(
          service_addr.removeprefix('grpc://'),
          str(self.logdir),
          worker_list,
          True,
          duration,
          num_tracing_attempts,
          options,
      )
      return respond(
          {'result': 'Profile captured successfully. Please refresh.'},
          'application/json',
      )
    except Exception as e:  # pylint: disable=broad-except
      return respond({'error': str(e)}, 'application/json', code=500)
|
|
1319
|
+
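Put together, capture_route_impl reads every tracing parameter from the query string. A sketch of a client triggering a capture, assuming the plugin is served by a hypothetical TensorBoard instance at http://localhost:6006 (the exact route prefix depends on how the plugin is mounted, so the URL below is an assumption):

import requests  # third-party HTTP client, assumed available

params = {
    'service_addr': 'grpc://10.0.0.2:8466',  # hypothetical profiler endpoint
    'duration': '2000',                      # trace for 2000 ms
    'num_retry': '2',                        # -> num_tracing_attempts == 3
    'host_tracer_level': '2',
    'device_tracer_level': '1',
    'python_tracer_level': '0',
    'delay': '0',                            # becomes options['delay_ms']
}
resp = requests.get(
    'http://localhost:6006/data/plugin/profile/capture_profile',
    params=params,
)
print(resp.json())  # {'result': ...} on success, {'error': ...} otherwise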

  def _get_graph_viewer_options(
      self, request: wrappers.Request
  ) -> dict[str, Any]:
    node_name = request.args.get('node_name')
    module_name = request.args.get('module_name')
    graph_width_str = request.args.get('graph_width') or ''
    graph_width = int(graph_width_str) if graph_width_str.isdigit() else 3
    show_metadata = int(request.args.get('show_metadata') == 'true')
    merge_fusion = int(request.args.get('merge_fusion') == 'true')
    return {
        'node_name': node_name,
        'module_name': module_name,
        'graph_width': graph_width,
        'show_metadata': show_metadata,
        'merge_fusion': merge_fusion,
        'format': request.args.get('format'),
        'type': request.args.get('type'),
    }
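The coercion rules above are easy to miss: graph_width falls back to 3 unless the raw string is all digits, and the two boolean flags become the integers 0 or 1. A sketch of the query parameters as a client might supply them (the values are illustrative, not taken from this file):

graph_viewer_query = {
    'node_name': 'fusion.123',        # hypothetical HLO instruction name
    'module_name': 'jit_train_step',  # hypothetical HLO module name
    'graph_width': '3',               # empty or non-digit -> default of 3
    'show_metadata': 'true',          # parsed to 1; anything else -> 0
    'merge_fusion': 'false',          # parsed to 0
    'format': 'json',                 # passed through unchanged
    'type': 'graph',                  # passed through unchanged
}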

  def generate_runs(self) -> Iterator[str]:
    """Generator for a list of runs.

    The "run name" here is a "frontend run name" - see _tb_run_directory() for
    the definition of a "frontend run name" and how it maps to a directory of
    profile data for a specific profile "run". The profile plugin concept of a
    "run" is different from the normal TensorBoard run; each run in this case
    represents a single instance of profile data collection, more similar to a
    "step" of data in typical TensorBoard semantics. These runs reside in
    subdirectories of the plugins/profile directory within any regular
    TensorBoard run directory, or within the session_dir root directory
    itself (even if it contains no tfevents file and would thus not be
    considered a normal TensorBoard run, for backwards compatibility).

    `generate_runs` gets all runs first; the tools list for a single run comes
    from `generate_tools_of_run`, because parsing the tools out of xspace data
    is expensive.

    Example:
      logs/
        plugins/
          profile/
            run1/
              hostA.trace
        train/
          events.out.tfevents.foo
          plugins/
            profile/
              run1/
                hostA.trace
                hostB.trace
              run2/
                hostA.trace
        validation/
          events.out.tfevents.foo
          plugins/
            profile/
              run1/
                hostA.trace
        new_job/
          tensorboard/
            plugins/
              profile/
                run1/
                  hostA.xplane.pb

    Yields:
      A sequence of strings that are "frontend run names".
      For the above example, this would be:
        "run1", "train/run1", "train/run2", "validation/run1",
        "new_job/tensorboard/run1"
    """
    if not self.logdir:
      return

    # Ensure that we check the root logdir and all subdirectories.
    # Note that we check whether logdir is a directory to handle the case
    # where it's actually a multipart directory spec, which this plugin does
    # not support.
    #
    # This still enforces the requirement that subdirectories must end with
    # a plugins/profile directory, as enforced by TensorBoard.
    logdir_path = epath.Path(self.logdir)
    schemeless_logdir = str(logdir_path)
    if '://' in schemeless_logdir:
      schemeless_logdir = schemeless_logdir.split('://', 1)[1]
    tb_runs = {'.'}

    if logdir_path.is_dir():
      try:
        fs = etils.epath.backend.fsspec_backend.fs(self.logdir)
        for path_str in fs.glob(os.path.join(self.logdir, '**', PLUGIN_NAME)):
          path = epath.Path(path_str)
          if fs.isdir(path) and path.parent.name == TB_NAME:
            tb_run_dir = path.parent.parent
            tb_run = tb_run_dir.relative_to(schemeless_logdir)
            tb_runs.add(str(tb_run))
      except ValueError:
        # gcsfs not available; fall back to a legacy path walk.
        for cur_dir, _, _ in logdir_path.walk():
          if cur_dir.name == PLUGIN_NAME and cur_dir.parent.name == TB_NAME:
            tb_run_dir = cur_dir.parent.parent
            tb_run = tb_run_dir.relative_to(logdir_path)
            tb_runs.add(str(tb_run))
    tb_run_names_to_dirs = {
        run: _tb_run_directory(self.logdir, run) for run in tb_runs
    }
    plugin_assets = _plugin_assets(
        self.logdir, list(tb_run_names_to_dirs), PLUGIN_NAME
    )
    visited_runs = set()
    for tb_run_name, profile_runs in six.iteritems(plugin_assets):
      tb_run_dir = tb_run_names_to_dirs[tb_run_name]
      tb_plugin_dir = plugin_asset_util.PluginDirectory(tb_run_dir, PLUGIN_NAME)

      for profile_run in profile_runs:
        # Remove trailing separator; some filesystem implementations emit it.
        profile_run = profile_run.rstrip(os.sep)
        if tb_run_name == '.':
          frontend_run = profile_run
        else:
          frontend_run = str(epath.Path(tb_run_name) / profile_run)
        profile_run_dir = str(epath.Path(tb_plugin_dir) / profile_run)
        if epath.Path(profile_run_dir).is_dir():
          self._run_to_profile_run_dir[frontend_run] = profile_run_dir
        if frontend_run not in visited_runs:
          visited_runs.add(frontend_run)
          yield frontend_run
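The mapping from the docstring's example tree to frontend run names can be reproduced with plain string handling. A sketch over that example layout, where 'plugins'/'profile' correspond to the TB_NAME/PLUGIN_NAME constants used in the code above:

profile_dirs = [
    'plugins/profile/run1',
    'train/plugins/profile/run1',
    'train/plugins/profile/run2',
    'validation/plugins/profile/run1',
    'new_job/tensorboard/plugins/profile/run1',
]
for d in profile_dirs:
  # Everything before plugins/profile/ is the TensorBoard run ('.' at root).
  tb_run, profile_run = d.rsplit('plugins/profile/', 1)
  tb_run = tb_run.rstrip('/') or '.'
  print(profile_run if tb_run == '.' else f'{tb_run}/{profile_run}')
# run1, train/run1, train/run2, validation/run1, new_job/tensorboard/run1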

  def generate_tools_of_run(self, run: str, run_dir: str) -> Iterator[str]:
    """Generates the list of tools available for a given run."""
    if not run_dir:
      logger.warning('Cannot find asset directory for: %s', run)
      return
    profile_run_dir = epath.Path(run_dir)
    cache = ToolsCache(profile_run_dir)

    cached_tools = cache.load()

    if cached_tools is not None:
      for tool in cached_tools:
        yield tool
      return

    # Cache is invalid or doesn't exist; regenerate.
    tools = []
    try:
      all_filenames = [f.name for f in profile_run_dir.iterdir()]
    except OSError as e:
      logger.warning(
          'Cannot read asset directory: %s, Error %r', profile_run_dir, e
      )
      return

    if all_filenames:
      tools = self._get_active_tools(all_filenames, str(profile_run_dir))
      cache.save(tools)

    for tool in tools:
      yield tool
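The method above is a load-or-regenerate cache around the expensive tool scan. The same pattern in isolation, assuming a cache object with the load()/save() interface of ToolsCache (which is defined elsewhere in this file):

def load_or_regenerate(cache, regenerate):
  """Returns cached tools, regenerating and saving them on a cache miss."""
  tools = cache.load()     # None signals a missing or invalidated cache
  if tools is None:
    tools = regenerate()   # e.g. scan filenames via _get_active_tools
    cache.save(tools)
  return tools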

  def _get_active_tools(self, filenames, profile_run_dir=''):
    """Gets the tools available given the filenames created by the profiler.

    Args:
      filenames: A list of profiler-generated filenames.
      profile_run_dir: The run directory of the profile.

    Returns:
      A list of strings representing the available tools.
    """
    tool_sort_order = [
        'overview_page',
        'trace_viewer',
        'trace_viewer@',
        'graph_viewer',
        'op_profile',
        'hlo_op_profile',
        'input_pipeline_analyzer',
        'input_pipeline',
        'kernel_stats',
        'memory_profile',
        'memory_viewer',
        'roofline_model',
        'perf_counters',
        'pod_viewer',
        'framework_op_stats',
        'tensorflow_stats',  # Legacy name for framework_op_stats.
        'hlo_op_stats',
        'hlo_stats',  # Legacy name for hlo_op_stats.
        'inference_profile',
        'megascale_stats',
    ]
    tools = _get_tools(filenames, profile_run_dir)
    if 'trace_viewer@' in tools:
      # The streaming trace viewer always overrides the normal trace viewer.
      # The trailing '@' informs tf-profile-dashboard.html and
      # tf-trace-viewer.html that the streaming trace viewer should be used.
      tools.discard('trace_viewer')

    sorted_tools = [t for t in tool_sort_order if t in tools]
    remaining_tools = tools.difference(sorted_tools)
    sorted_tools.extend(sorted(remaining_tools))

    return sorted_tools
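The ordering logic can be checked in isolation: known tools appear in the fixed order above, unknown tools follow alphabetically, and the streaming trace viewer displaces the normal one. A self-contained sketch with an abbreviated sort order (the tool names besides 'some_future_tool' come from the list above):

tool_sort_order = ['overview_page', 'trace_viewer', 'trace_viewer@',
                   'graph_viewer', 'memory_viewer']  # abbreviated from above
tools = {'memory_viewer', 'trace_viewer', 'trace_viewer@',
         'overview_page', 'some_future_tool'}
tools.discard('trace_viewer')  # the streaming viewer wins, as above
sorted_tools = [t for t in tool_sort_order if t in tools]
sorted_tools.extend(sorted(tools.difference(sorted_tools)))
print(sorted_tools)
# ['overview_page', 'trace_viewer@', 'memory_viewer', 'some_future_tool']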