xprof-nightly 2.22.3a20251208__cp311-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. xprof/__init__.py +22 -0
  2. xprof/convert/_pywrap_profiler_plugin.so +0 -0
  3. xprof/convert/csv_writer.py +87 -0
  4. xprof/convert/raw_to_tool_data.py +232 -0
  5. xprof/convert/trace_events_json.py +105 -0
  6. xprof/integration_tests/tf_mnist.py +100 -0
  7. xprof/integration_tests/tf_profiler_session.py +40 -0
  8. xprof/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py +183 -0
  9. xprof/profile_plugin.py +1521 -0
  10. xprof/profile_plugin_loader.py +82 -0
  11. xprof/protobuf/dcn_collective_info_pb2.py +44 -0
  12. xprof/protobuf/dcn_slack_analysis_pb2.py +42 -0
  13. xprof/protobuf/diagnostics_pb2.py +36 -0
  14. xprof/protobuf/event_time_fraction_analyzer_pb2.py +42 -0
  15. xprof/protobuf/hardware_types_pb2.py +40 -0
  16. xprof/protobuf/hlo_stats_pb2.py +39 -0
  17. xprof/protobuf/inference_stats_pb2.py +86 -0
  18. xprof/protobuf/input_pipeline_pb2.py +52 -0
  19. xprof/protobuf/kernel_stats_pb2.py +38 -0
  20. xprof/protobuf/memory_profile_pb2.py +54 -0
  21. xprof/protobuf/memory_viewer_preprocess_pb2.py +49 -0
  22. xprof/protobuf/op_metrics_pb2.py +65 -0
  23. xprof/protobuf/op_profile_pb2.py +49 -0
  24. xprof/protobuf/op_stats_pb2.py +71 -0
  25. xprof/protobuf/overview_page_pb2.py +64 -0
  26. xprof/protobuf/pod_stats_pb2.py +45 -0
  27. xprof/protobuf/pod_viewer_pb2.py +61 -0
  28. xprof/protobuf/power_metrics_pb2.py +38 -0
  29. xprof/protobuf/roofline_model_pb2.py +42 -0
  30. xprof/protobuf/smart_suggestion_pb2.py +38 -0
  31. xprof/protobuf/source_info_pb2.py +36 -0
  32. xprof/protobuf/source_stats_pb2.py +48 -0
  33. xprof/protobuf/steps_db_pb2.py +76 -0
  34. xprof/protobuf/task_pb2.py +37 -0
  35. xprof/protobuf/tf_data_stats_pb2.py +72 -0
  36. xprof/protobuf/tf_function_pb2.py +52 -0
  37. xprof/protobuf/tf_stats_pb2.py +40 -0
  38. xprof/protobuf/tfstreamz_pb2.py +40 -0
  39. xprof/protobuf/topology_pb2.py +50 -0
  40. xprof/protobuf/tpu_input_pipeline_pb2.py +43 -0
  41. xprof/protobuf/trace_events_old_pb2.py +54 -0
  42. xprof/protobuf/trace_events_pb2.py +64 -0
  43. xprof/protobuf/trace_events_raw_pb2.py +45 -0
  44. xprof/protobuf/trace_filter_config_pb2.py +40 -0
  45. xprof/server.py +319 -0
  46. xprof/standalone/base_plugin.py +52 -0
  47. xprof/standalone/context.py +22 -0
  48. xprof/standalone/data_provider.py +32 -0
  49. xprof/standalone/plugin_asset_util.py +131 -0
  50. xprof/standalone/plugin_event_multiplexer.py +185 -0
  51. xprof/standalone/tensorboard_shim.py +31 -0
  52. xprof/static/bundle.js +130500 -0
  53. xprof/static/index.html +64 -0
  54. xprof/static/index.js +3 -0
  55. xprof/static/materialicons.woff2 +0 -0
  56. xprof/static/styles.css +1 -0
  57. xprof/static/trace_viewer_index.html +3929 -0
  58. xprof/static/trace_viewer_index.js +15906 -0
  59. xprof/static/zone.js +3558 -0
  60. xprof/version.py +17 -0
  61. xprof_nightly-2.22.3a20251208.dist-info/METADATA +301 -0
  62. xprof_nightly-2.22.3a20251208.dist-info/RECORD +65 -0
  63. xprof_nightly-2.22.3a20251208.dist-info/WHEEL +5 -0
  64. xprof_nightly-2.22.3a20251208.dist-info/entry_points.txt +5 -0
  65. xprof_nightly-2.22.3a20251208.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1521 @@
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """The TensorBoard plugin for performance profiling."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ from collections.abc import Callable, Iterator, Mapping
22
+ import gzip
23
+ import json
24
+ import logging
25
+ import os
26
+ import re
27
+ import sys
28
+ import threading
29
+ from typing import Any, Dict, List, Optional, Sequence, TypedDict
30
+
31
+ from etils import epath
32
+ import etils.epath.backend
33
+ from fsspec import core
34
+ import six
35
+ from werkzeug import wrappers
36
+
37
+ from xprof import version
38
+ from xprof.convert import raw_to_tool_data as convert
39
+ from xprof.standalone.tensorboard_shim import base_plugin
40
+ from xprof.standalone.tensorboard_shim import plugin_asset_util
41
+ from xprof.convert import _pywrap_profiler_plugin
42
+
43
+ logger = logging.getLogger('tensorboard.plugins.profile')
44
+ logger.setLevel(logging.INFO)
45
+ if not logger.handlers:
46
+ handler = logging.StreamHandler(sys.stderr)
47
+ formatter = logging.Formatter(
48
+ '%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s'
49
+ )
50
+ handler.setFormatter(formatter)
51
+ logger.addHandler(handler)
52
+ logger.propagate = False
53
+
54
+ try:
55
+ import tensorflow.compat.v2 as tf # pylint: disable=g-import-not-at-top # pytype: disable=import-error
56
+
57
+ tf.enable_v2_behavior()
58
+ except ImportError:
59
+ logger.info(
60
+ 'Disabling some remote capture features as tensorflow is not available'
61
+ )
62
+ tf = None
63
+
64
+
65
+ # The prefix of routes provided by this plugin.
66
+ TB_NAME = 'plugins'
67
+ PLUGIN_NAME = 'profile'
68
+
69
+ BASE_ROUTE = '/'
70
+ INDEX_JS_ROUTE = '/index.js'
71
+ INDEX_HTML_ROUTE = '/index.html'
72
+ BUNDLE_JS_ROUTE = '/bundle.js'
73
+ STYLES_CSS_ROUTE = '/styles.css'
74
+ MATERIALICONS_WOFF2_ROUTE = '/materialicons.woff2'
75
+ TRACE_VIEWER_INDEX_HTML_ROUTE = '/trace_viewer_index.html'
76
+ TRACE_VIEWER_INDEX_JS_ROUTE = '/trace_viewer_index.js'
77
+ ZONE_JS_ROUTE = '/zone.js'
78
+ DATA_ROUTE = '/data'
79
+ RUNS_ROUTE = '/runs'
80
+ RUN_TOOLS_ROUTE = '/run_tools'
81
+ HOSTS_ROUTE = '/hosts'
82
+ HLO_MODULE_LIST_ROUTE = '/module_list'
83
+ CAPTURE_ROUTE = '/capture_profile'
84
+ LOCAL_ROUTE = '/local'
85
+ CONFIG_ROUTE = '/config'
86
+ CACHE_VERSION_FILE = 'cache_version.txt'
87
+
88
+ # Suffixes of "^, #, @" symbols represent different input data formats for the
89
+ # same tool.
90
+ # 1) '^': data generated from XPlane.
91
+ # 2) '#': data is in gzip format.
92
+ # 3) '@': data generated from proto, or tracetable for streaming trace viewer.
93
+ # 4) no suffix: data is in json format, ready to feed to frontend.
94
+ TOOLS = {
95
+ 'xplane': 'xplane.pb',
96
+ 'hlo_proto': 'hlo_proto.pb',
97
+ }
98
+
99
+ ALL_HOSTS = 'ALL_HOSTS'
100
+
101
+ HostMetadata = TypedDict('HostMetadata', {'hostname': str})
102
+
103
+ _EXTENSION_TO_TOOL = {extension: tool for tool, extension in TOOLS.items()}
104
+
105
+ _FILENAME_RE = re.compile(r'(?:(.*)\.)?(' +
106
+ '|'.join(TOOLS.values()).replace('.', r'\.') + r')')
107
+
108
+
109
+ # Tools that can be generated from xplane (names no longer carry a '^' suffix).
110
+ XPLANE_TOOLS = [
111
+ 'trace_viewer', # non-streaming before TF 2.13
112
+ 'trace_viewer@', # streaming since TF 2.14
113
+ 'overview_page',
114
+ 'input_pipeline_analyzer',
115
+ 'framework_op_stats',
116
+ 'kernel_stats',
117
+ 'memory_profile',
118
+ 'pod_viewer',
119
+ 'op_profile',
120
+ 'hlo_stats',
121
+ 'roofline_model',
122
+ 'inference_profile',
123
+ 'memory_viewer',
124
+ 'graph_viewer',
125
+ 'megascale_stats',
126
+ ]
127
+
128
+ # XPlane generated tools that support all host mode.
129
+ XPLANE_TOOLS_ALL_HOSTS_SUPPORTED = frozenset([
130
+ 'input_pipeline_analyzer',
131
+ 'framework_op_stats',
132
+ 'kernel_stats',
133
+ 'overview_page',
134
+ 'pod_viewer',
135
+ 'megascale_stats',
136
+ ])
137
+
138
+ # XPlane generated tools that only support all host mode.
139
+ XPLANE_TOOLS_ALL_HOSTS_ONLY = frozenset(
140
+ ['overview_page', 'pod_viewer'])
141
+
142
+ # Rate limiter constants. The GCS rate quota is documented at
143
+ # https://cloud.google.com/storage/quotas#rate-quotas.
144
+ # Currently set to 1000 requests per minute.
145
+ # TODO(kcai): The assumption on the average number of subdirs is not
146
+ # always true. If this is not sufficient, we can consider a token-based
147
+ # approach that counts the number of subdirs after calling iterdir.
148
+ MAX_GCS_REQUESTS = 1000
149
+ LIMIT_WINDOW_SECONDS = 60
150
+ AVERAGE_SUBDIR_NUMBER = 10
151
+
152
+
153
def use_xplane(tool: str) -> bool:
  """Reports whether `tool` is listed in XPLANE_TOOLS."""
  is_xplane_tool = tool in XPLANE_TOOLS
  return is_xplane_tool
155
+
156
+
157
+ # HLO generated tools.
158
+ HLO_TOOLS = frozenset(['graph_viewer', 'memory_viewer'])
159
+
160
+
161
def use_hlo(tool: str) -> bool:
  """Reports whether `tool` is listed in HLO_TOOLS."""
  is_hlo_tool = tool in HLO_TOOLS
  return is_hlo_tool
163
+
164
+
165
def make_filename(host: str, tool: str) -> str:
  """Builds the data file name for a host and tool.

  HLO tools are checked first and map to the 'hlo_proto' extension; XPlane
  tools map to the 'xplane' extension. Any other tool name is looked up in
  TOOLS directly.

  Args:
    host: Name of the host that produced the profile data, e.g., 'localhost'.
      A falsy host yields a name without the '<host>.' prefix.
    tool: Name of the tool, e.g., 'trace_viewer'.

  Returns:
    The optional '<host>.' prefix followed by the tool's extension from TOOLS,
    e.g., 'localhost.xplane.pb'.

  Raises:
    KeyError: If the resolved tool name has no entry in TOOLS.
  """
  prefix = f'{host}.' if host else ''
  if use_hlo(tool):
    extension_key = 'hlo_proto'
  elif use_xplane(tool):
    extension_key = 'xplane'
  else:
    extension_key = tool
  return prefix + TOOLS[extension_key]
182
+
183
+
184
def _parse_filename(filename: str) -> tuple[Optional[str], Optional[str]]:
  """Splits a run-directory file name into its host and tool parts.

  The name must match _FILENAME_RE in full: an optional '<host>.' prefix
  followed by one of the extensions listed in TOOLS.

  Args:
    filename: Name of a file in the run directory, e.g. 'host.xplane.pb' or
      just 'xplane.pb'.

  Returns:
    A tuple (host, tool). The host is None when the name has no host prefix.
    When the name does not match at all, the whole filename is returned as the
    host and the tool is None.
  """
  match = _FILENAME_RE.fullmatch(filename)
  if match is not None:
    host, extension = match.group(1), match.group(2)
    return host, _EXTENSION_TO_TOOL[extension]
  return filename, None
200
+
201
+
202
def _get_hosts(filenames: list[str]) -> set[str]:
  """Collects the distinct host names encoded in a list of file names.

  Args:
    filenames: A list of filenames (just basenames, no directory).

  Returns:
    The set of non-empty host components reported by _parse_filename.
  """
  parsed = (_parse_filename(name) for name in filenames)
  return {host for host, _ in parsed if host}
217
+
218
+
219
def _get_tools(filenames: list[str], profile_run_dir: str) -> set[str]:
  """Parses a list of filenames and returns the set of tools.

  If xplane files are present, the tools that can be generated from xplane are
  added as well: via the converter when a run directory is known, otherwise
  from the static XPLANE_TOOLS list.

  Args:
    filenames: A list of filenames.
    profile_run_dir: The run directory of the profile. May be empty.

  Returns:
    A set of tool names. When a run directory is given and xplane files exist,
    this is exactly the set reported by the xplane converter.
  """
  tools = set()
  # Tool names already backed by a file, with any '@' format suffix removed.
  found = set()
  xplane_filenames = []
  for name in filenames:
    _, tool = _parse_filename(name)
    if tool == 'xplane':
      xplane_filenames.append(os.path.join(profile_run_dir, name))
    elif tool and tool != 'hlo_proto':
      tools.add(tool)
      found.add(tool.removesuffix('@'))

  if not xplane_filenames:
    return tools

  # profile_run_dir might be empty, like in cloud AI use case.
  if not profile_run_dir:
    for item in XPLANE_TOOLS:
      # Only strip a real '@' suffix; slicing unconditionally would cut a
      # character off suffix-less names and never match anything in `found`.
      if item.removesuffix('@') not in found:
        tools.add(item)
    return tools

  try:
    return set(convert.xspace_to_tool_names(xplane_filenames))
  except AttributeError:
    logger.warning('XPlane converters are available after Tensorflow 2.4')
  return tools
261
+
262
+
263
def respond(
    body: Any,
    content_type: str,
    code: int = 200,
    content_encoding: Optional[str] = None,
) -> wrappers.Response:
  """Create a Werkzeug response, handling JSON serialization and CSP.

  Args:
    body: For JSON responses, a JSON-serializable object; otherwise, a raw
      `bytes` string or Unicode `str` (which will be encoded as UTF-8). A
      top-level `set` is emitted as a JSON array.
    content_type: Response content-type (`str`); use `application/json` to
      automatically serialize structures.
    code: HTTP status code (`int`).
    content_encoding: Response Content-Encoding header (`str`); e.g. 'gzip'.
      If the content encoding is not set, the data is gzip-compressed here and
      the Content-Encoding header is set to gzip.

  Returns:
    A `werkzeug.wrappers.Response` object.
  """
  if content_type == 'application/json' and isinstance(
      body, (dict, list, set, tuple)
  ):
    if isinstance(body, set):
      # json.dumps rejects sets; convert to a list, sorted when the elements
      # are orderable so the payload is deterministic.
      try:
        body = sorted(body)
      except TypeError:
        body = list(body)
    body = json.dumps(body, sort_keys=True)
  if not isinstance(body, bytes):
    body = body.encode('utf-8')
  csp_parts = {
      'default-src': ["'self'"],
      'script-src': [
          "'self'",
          "'unsafe-eval'",
          "'unsafe-inline'",
          'https://www.gstatic.com',
      ],
      'object-src': ["'none'"],
      'style-src': [
          "'self'",
          "'unsafe-inline'",
          'https://fonts.googleapis.com',
          'https://www.gstatic.com',
      ],
      'font-src': [
          "'self'",
          'https://fonts.googleapis.com',
          'https://fonts.gstatic.com',
          'data:',
      ],
      'connect-src': [
          "'self'",
          'data:',
          'www.gstatic.com',
      ],
      'img-src': [
          "'self'",
          'blob:',
          'data:',
      ],
      'script-src-elem': [
          "'self'",
          "'unsafe-inline'",
          # Remember to restrict on integrity when importing from jsdelivr
          # Whitelist this domain to support hlo_graph_dumper html format
          'https://cdn.jsdelivr.net/npm/',
          'https://www.gstatic.com',
      ],
  }
  # Each directive becomes "name source source ..."; directives are joined
  # with ';'.
  csp = ';'.join(
      ' '.join([directive, *sources])
      for directive, sources in csp_parts.items()
  )
  headers = [
      ('Content-Security-Policy', csp),
      ('X-Content-Type-Options', 'nosniff'),
  ]
  if content_encoding:
    # The caller states the body is already encoded this way; pass it through.
    headers.append(('Content-Encoding', content_encoding))
  else:
    headers.append(('Content-Encoding', 'gzip'))
    body = gzip.compress(body)
  return wrappers.Response(
      body, content_type=content_type, status=code, headers=headers
  )
342
+
343
+
344
def _plugin_assets(
    session_dir: str, runs: list[str], plugin_name: str
) -> dict[str, list[str]]:
  """Maps each run name to the plugin assets listed in its run directory."""
  return {
      run: plugin_asset_util.ListAssets(
          _tb_run_directory(session_dir, run), plugin_name
      )
      for run in runs
  }
353
+
354
+
355
+ def _tb_run_directory(session_dir: str, run: str) -> str:
356
+ """Returns the TensorBoard run directory for a TensorBoard run name.
357
+
358
+ This helper returns the TensorBoard-level run directory (the one that would)
359
+ contain tfevents files) for a given TensorBoard run name (aka the relative
360
+ path from the session_dir root to this directory). For the root run '.'
361
+ this is the bare session_dir path; for all other runs this is the
362
+ session_dir joined with the run name.
363
+
364
+ Args:
365
+ session_dir: the TensorBoard log directory root path
366
+ run: the TensorBoard run name, e.g. '.' or 'train'
367
+
368
+ Returns:
369
+ The TensorBoard run directory path, e.g. my/session_dir or
370
+ my/session_dir/train.
371
+ """
372
+ return session_dir if run == '.' else os.path.join(session_dir, run)
373
+
374
+
375
def filenames_to_hosts(filenames: list[str], tool: str) -> list[str]:
  """Convert a list of filenames to a sorted list of host names for a tool.

  When more than one host is present, tools in XPLANE_TOOLS_ALL_HOSTS_ONLY
  report only ALL_HOSTS, and tools in XPLANE_TOOLS_ALL_HOSTS_SUPPORTED report
  ALL_HOSTS in addition to the individual hosts.

  Args:
    filenames: A list of filenames.
    tool: A string representing the profiling tool.

  Returns:
    A sorted list of hostnames.
  """
  hosts = _get_hosts(filenames)
  has_multiple_hosts = len(hosts) > 1
  if has_multiple_hosts and tool in XPLANE_TOOLS_ALL_HOSTS_ONLY:
    return sorted([ALL_HOSTS])
  if has_multiple_hosts and tool in XPLANE_TOOLS_ALL_HOSTS_SUPPORTED:
    hosts.add(ALL_HOSTS)
  return sorted(hosts)
392
+
393
+
394
+ def _get_bool_arg(
395
+ args: Mapping[str, Any], arg_name: str, default: bool
396
+ ) -> bool:
397
+ """Gets a boolean argument from a request.
398
+
399
+ Args:
400
+ args: The werkzeug request arguments.
401
+ arg_name: The name of the argument.
402
+ default: The default value if the argument is not present.
403
+
404
+ Returns:
405
+ The boolean value of the argument.
406
+ """
407
+ arg_str = args.get(arg_name)
408
+ if arg_str is None:
409
+ return default
410
+ return arg_str.lower() == 'true'
411
+
412
+
413
class ToolsCache:
  """Caches the list of tools for a profile run.

  The cache is a JSON file stored in the profile run directory. It records the
  cache format version, an identifier for every XPlane file in the directory
  (MD5 hash for GCS paths, "{mtime}-{size}" for local paths) and the tool list.
  A cached tool list is only returned while those identifiers are unchanged.

  Attributes:
    CACHE_FILE_NAME: The name of the cache file.
    CACHE_VERSION: The version of the cache format.
  """

  CACHE_FILE_NAME = '.cached_tools.json'
  CACHE_VERSION = 1

  def __init__(self, profile_run_dir: epath.Path):
    """Initializes the ToolsCache.

    Args:
      profile_run_dir: The directory containing the profile run data.
    """
    self._profile_run_dir = profile_run_dir
    self._cache_file = self._profile_run_dir / self.CACHE_FILE_NAME
    logger.info('ToolsCache initialized for %s', self._cache_file)

  def _get_local_file_identifier(self, file_path_str: str) -> Optional[str]:
    """Gets a string identifier for a local file.

    The identifier is a combination of the file's last modification time
    (mtime, truncated to whole seconds) and size, in the format
    "{mtime}-{size}".

    Args:
      file_path_str: The path to the local file.

    Returns:
      A string identifier, or None if the file is not found or an error occurs.
    """
    try:
      stat_result = os.stat(file_path_str)
      return f'{int(stat_result.st_mtime)}-{stat_result.st_size}'
    except FileNotFoundError:
      logger.warning('Local file not found: %s', file_path_str)
      return None
    except OSError as e:
      logger.error(
          'OSError getting stat for local file %s: %r', file_path_str, e
      )
      return None

  def _get_gcs_file_hash(self, file_path_str: str) -> Optional[str]:
    """Gets the MD5 hash for a GCS file.

    Args:
      file_path_str: The GCS path (e.g., "gs://bucket/object").

    Returns:
      The 'md5Hash' string from the filesystem info, or None if the file is not
      found, the field is missing or not a string, or an error occurs.
    """
    try:
      fs = core.get_fs_token_paths(file_path_str)[0]
      info = fs.info(file_path_str)
      md5_hash = info.get('md5Hash')

      if not isinstance(md5_hash, str):
        logger.warning(
            'Could not find a valid md5Hash string in info for %s: %s',
            file_path_str,
            info,
        )
        return None

      return md5_hash

    except FileNotFoundError:
      logger.warning('GCS path not found: %s', file_path_str)
      return None
    except IndexError:
      logger.error('Could not get filesystem for GCS path: %s', file_path_str)
      return None
    except Exception as e:  # pylint: disable=broad-exception-caught
      # Best effort: any failure here only means the cache cannot be used.
      logger.exception(
          'Unexpected error getting hash for GCS path %s: %r', file_path_str, e
      )
      return None

  def get_file_identifier(self, file_path_str: str) -> Optional[str]:
    """Gets a string identifier for a file.

    For GCS files (paths starting with 'gs://'), this is the MD5 hash.
    For all other paths, this is a string combining mtime and size.

    Args:
      file_path_str: The full path to the file (local or GCS).

    Returns:
      A string identifier, or None if an error occurs.
    """
    if file_path_str.startswith('gs://'):
      return self._get_gcs_file_hash(file_path_str)
    else:
      return self._get_local_file_identifier(file_path_str)

  def _get_current_xplane_file_states(self) -> Optional[Dict[str, str]]:
    """Gets the current state of XPlane files in the profile run directory.

    Returns:
      A dictionary mapping filename to a string identifier (hash or
      mtime-size), or None if any file state cannot be determined.
    """
    try:
      file_identifiers = {}
      for xplane_file in self._profile_run_dir.glob(f"*.{TOOLS['xplane']}"):
        file_id = self.get_file_identifier(str(xplane_file))
        if file_id is None:
          logger.warning(
              'Could not get identifier for %s, cache will be invalidated.',
              xplane_file,
          )
          return None
        file_identifiers[xplane_file.name] = file_id
      return file_identifiers
    except OSError as e:
      logger.warning(
          'Could not glob files in %s: %r', self._profile_run_dir, e
      )
      return None

  def load(self) -> Optional[List[str]]:
    """Loads the cached list of tools if the cache is valid.

    The cache is valid if the cache file exists and holds a JSON object, the
    version matches, and the file states (hashes/mtimes) of the XPlane files
    have not changed. An invalid cache file is deleted.

    Returns:
      A list of tool names if the cache is valid, otherwise None.
    """
    try:
      with self._cache_file.open('r') as f:
        cached_data = json.load(f)
    except (OSError, ValueError) as e:
      # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
      logger.warning(
          'Error reading or decoding cache file %s: %r, invalidating.',
          self._cache_file,
          e,
      )
      self.invalidate()
      return None

    if not isinstance(cached_data, dict):
      # Valid JSON that is not an object (e.g. a list or a bare string) would
      # otherwise fail on the `.get` calls below.
      logger.warning(
          'ToolsCache invalid: expected a JSON object in %s, got %s.'
          ' Invalidating.',
          self._cache_file,
          type(cached_data).__name__,
      )
      self.invalidate()
      return None

    if cached_data.get('version') != self.CACHE_VERSION:
      logger.info(
          'ToolsCache invalid: version mismatch, expected %s, got %s.'
          ' Invalidating %s',
          self.CACHE_VERSION,
          cached_data.get('version'),
          self._cache_file,
      )
      self.invalidate()
      return None

    current_files = self._get_current_xplane_file_states()
    if current_files is None:
      logger.info(
          'ToolsCache invalid: could not determine current file states.'
          ' Invalidating %s',
          self._cache_file,
      )
      self.invalidate()
      return None

    if cached_data.get('files') != current_files:
      logger.info(
          'ToolsCache invalid: file states differ. Invalidating %s',
          self._cache_file,
      )
      self.invalidate()
      return None

    logger.info('ToolsCache hit: %s', self._cache_file)
    return cached_data.get('tools')

  def save(self, tools: Sequence[str]) -> None:
    """Saves the list of tools and the current file states to the cache file.

    Nothing is written when the current file states cannot be determined.
    Write failures are logged, not raised.

    Args:
      tools: The list of tool names to cache.
    """
    current_files_for_cache = self._get_current_xplane_file_states()
    if current_files_for_cache is None:
      logger.warning(
          'ToolsCache not saved: could not get file states %s',
          self._cache_file,
      )
      return

    new_cache_data = {
        'version': self.CACHE_VERSION,
        'files': current_files_for_cache,
        'tools': tools,
    }
    try:
      with self._cache_file.open('w') as f:
        json.dump(new_cache_data, f, sort_keys=True, indent=2)
      logger.info('ToolsCache saved: %s', self._cache_file)
    except (OSError, TypeError) as e:
      logger.error('Error writing cache file %s: %r', self._cache_file, e)

  def invalidate(self) -> None:
    """Deletes the cache file, forcing regeneration on the next load."""
    try:
      self._cache_file.unlink()
      logger.info('ToolsCache invalidated: %s', self._cache_file)
    except FileNotFoundError:
      # Nothing to delete; the cache is already absent.
      pass
    except OSError as e:
      logger.error('Error removing cache file %s: %r', self._cache_file, e)
621
+
622
+
623
+ class _TfProfiler:
624
+ """A helper class to encapsulate all TensorFlow-dependent profiler logic."""
625
+
626
+ def __init__(self, tf_module):
627
+ if not tf_module:
628
+ raise ImportError('TensorFlow module is not available.')
629
+ self.tf = tf_module
630
+
631
+ def _get_worker_list(self, cluster_resolver) -> str:
632
+ """Parses TPU workers list from the cluster resolver."""
633
+ cluster_spec = cluster_resolver.cluster_spec()
634
+ task_indices = cluster_spec.task_indices('worker')
635
+ worker_list = [
636
+ cluster_spec.task_address('worker', i).replace(':8470', ':8466')
637
+ for i in task_indices
638
+ ]
639
+ return ','.join(worker_list)
640
+
641
+ def resolve_tpu_name(
642
+ self, tpu_name: str, worker_list: str
643
+ ) -> tuple[str, str, str]:
644
+ """Resolves a TPU name to its master IP, service address, and worker list.
645
+
646
+ Args:
647
+ tpu_name: The name of the TPU to resolve.
648
+ worker_list: A comma-separated list of worker addresses.
649
+
650
+ Returns:
651
+ A tuple containing (service_addr, worker_list, master_ip).
652
+ """
653
+ try:
654
+ resolver = self.tf.distribute.cluster_resolver.TPUClusterResolver(
655
+ tpu_name
656
+ )
657
+ master_grpc_addr = resolver.get_master()
658
+ except RuntimeError as err:
659
+ # Propagate error to be handled by the caller.
660
+ raise RuntimeError(
661
+ f'Error initializing TPUClusterResolver: {err}'
662
+ ) from err
663
+ except (ValueError, TypeError) as e:
664
+ # Handle cases where the TPU name is invalid.
665
+ raise ValueError(f'No TPU found with the name: {tpu_name}') from e
666
+
667
+ if not worker_list:
668
+ worker_list = self._get_worker_list(resolver)
669
+
670
+ # TPU cluster resolver always returns port 8470. Replace it with 8466
671
+ # on which profiler service is running.
672
+ master_ip = master_grpc_addr.replace('grpc://', '').replace(':8470', '')
673
+ service_addr = f'{master_ip}:8466'
674
+ return service_addr, worker_list, master_ip
675
+
676
+
677
+ class ProfilePlugin(base_plugin.TBPlugin):
678
+ """Profile Plugin for TensorBoard."""
679
+ plugin_name = PLUGIN_NAME
680
+
681
def __init__(self, context):
  """Constructs a profiler plugin for TensorBoard.

  This plugin adds handlers for performance-related frontends.

  Args:
    context: A base_plugin.TBContext instance.
  """
  self.logdir = context.logdir
  self.data_provider = context.data_provider
  self.master_tpu_unsecure_channel = context.flags.master_tpu_unsecure_channel
  # Optional on the context; the capture button is shown unless it is set.
  self.hide_capture_profile_button = getattr(
      context, 'hide_capture_profile_button', False
  )

  # Activity flag. Starts False; is_active() stores its computed result here
  # and stops recomputing once it is truthy.
  self._is_active = False
  # Lock intended to guard the computation of _is_active.
  # NOTE(review): is_active() does not acquire it in the visible code.
  self._is_active_lock = threading.Lock()
  # Cache to map profile run name to corresponding tensorboard dir name
  self._run_to_profile_run_dir = {}
  # TensorFlow-backed helpers exist only when tensorflow could be imported.
  if tf:
    self._tf_profiler = _TfProfiler(tf)
  else:
    self._tf_profiler = None
703
+
704
def is_active(self) -> bool:
  """Whether this plugin is active and has any profile data to show.

  A truthy result is remembered, so runs are only enumerated again while no
  run has been found yet.

  Returns:
    Whether any run has profile data.
  """
  if not self._is_active:
    has_any_run = any(self.generate_runs())
    self._is_active = has_any_run
  return self._is_active
713
+
714
def _does_tool_support_multi_hosts_processing(self, tool: str) -> bool:
  """Returns true for the two trace viewer tool names, false otherwise."""
  return tool in ('trace_viewer@', 'trace_viewer')
717
+
718
def get_plugin_apps(
    self,
) -> dict[str, Callable[[wrappers.Request], wrappers.Response]]:
  """Returns the mapping from route path to the handler serving it."""
  static_routes = (
      INDEX_JS_ROUTE,
      INDEX_HTML_ROUTE,
      BUNDLE_JS_ROUTE,
      STYLES_CSS_ROUTE,
      MATERIALICONS_WOFF2_ROUTE,
      TRACE_VIEWER_INDEX_HTML_ROUTE,
      TRACE_VIEWER_INDEX_JS_ROUTE,
      ZONE_JS_ROUTE,
  )
  # Built in three steps so the key order matches the route declaration order.
  apps = {BASE_ROUTE: self.default_handler}
  apps.update(dict.fromkeys(static_routes, self.static_file_route))
  apps.update({
      RUNS_ROUTE: self.runs_route,
      RUN_TOOLS_ROUTE: self.run_tools_route,
      HOSTS_ROUTE: self.hosts_route,
      DATA_ROUTE: self.data_route,
      HLO_MODULE_LIST_ROUTE: self.hlo_module_list_route,
      CAPTURE_ROUTE: self.capture_route,
      LOCAL_ROUTE: self.default_handler,
      CONFIG_ROUTE: self.config_route,
  })
  return apps
740
+
741
+ # pytype: disable=wrong-arg-types
742
@wrappers.Request.application
def default_handler(self, _: wrappers.Request) -> wrappers.Response:
  """Serves the bundled index.html for any request routed here."""
  index_page = self._read_static_file_impl('index.html')
  return respond(index_page, 'text/html')
746
+
747
+ # pytype: disable=wrong-arg-types
748
@wrappers.Request.application
def config_route(self, request: wrappers.Request) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  """Returns UI configuration details as a JSON response."""
  logger.info('config_route: %s', self.logdir)
  return respond(
      {'hideCaptureProfileButton': self.hide_capture_profile_button},
      'application/json',
  )
757
+
758
def frontend_metadata(self):
  """Returns frontend metadata naming '/index.js' as the ES module path."""
  metadata = base_plugin.FrontendMetadata(es_module_path='/index.js')
  return metadata
760
+
761
def _read_static_file_impl(self, filename: str) -> bytes:
  """Reads a file from the 'static' directory next to this module.

  Args:
    filename (str): Name of the file.

  Returns:
    The raw bytes of the file.

  Raises:
    IOError: File could not be read or found.
  """
  static_dir = os.path.join(os.path.dirname(__file__), 'static')
  # Any IOError from open/read propagates to the caller unchanged.
  with open(os.path.join(static_dir, filename), 'rb') as infile:
    return infile.read()
780
+
781
+ # pytype: disable=wrong-arg-types
782
@wrappers.Request.application
def static_file_route(self, request: wrappers.Request) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  """Serves the static file named by the last segment of the request path.

  The content type is chosen from the file extension. Unknown extensions are
  served as 'application/octet-stream'. An unreadable file yields a 404.
  """
  mimetype_by_extension = {
      '.html': 'text/html',
      '.css': 'text/css',
      '.js': 'application/javascript',
  }
  # Only the basename is used, so the lookup stays inside the static folder.
  filename = os.path.basename(request.path)
  _, extension = os.path.splitext(filename)
  mimetype = mimetype_by_extension.get(extension, 'application/octet-stream')
  try:
    contents = self._read_static_file_impl(filename)
  except IOError:
    return respond('Fail to read the files.', 'text/plain', code=404)
  return respond(contents, mimetype)
800
+
801
+ # pytype: disable=wrong-arg-types
802
@wrappers.Request.application
def runs_route(self, request: wrappers.Request) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  """Responds with the result of runs_imp for this request, as JSON."""
  return respond(self.runs_imp(request), 'application/json')
807
+
808
def _run_map_from_request(
    self, request: Optional[wrappers.Request] = None
) -> Optional[dict[str, str]]:
  """Returns a map of run names to session directories from the request.

  The 'session_path' argument takes precedence over 'run_path'. A directory
  counts as a session only if it contains at least one '*.xplane.pb' file.

  Args:
    request: Optional; werkzeug request used for grabbing session_path and
      run_path arguments.

  Returns:
    None when there is no request or neither argument is set. Otherwise a
    dict mapping each qualifying session directory's name to its path; the
    dict may be empty.
  """
  if not request:
    return None

  session_path_arg = request.args.get('session_path')
  if session_path_arg:
    # A single session directory was named explicitly.
    session_path = epath.Path(session_path_arg)
    if session_path.is_dir() and any(session_path.glob('*.xplane.pb')):
      return {session_path.name: str(session_path)}
    return {}

  run_path_arg = request.args.get('run_path')
  if not run_path_arg:
    return None

  # Every qualifying child directory of run_path is a session.
  return {
      session.name: str(session)
      for session in epath.Path(run_path_arg).iterdir()
      if session.is_dir() and any(session.glob('*.xplane.pb'))
  }
837
+
838
def _run_dir(
    self, run: str, request: Optional[wrappers.Request] = None
) -> Optional[str]:
  """Resolves a frontend run name to its profile "run" directory.

  A frontend run name is the TensorBoard run name (the relative path from the
  logdir root to the directory containing the data) path-joined to the
  Profile plugin's own "run" name (a subdirectory of plugins/profile that
  represents one run of the tool). When the TensorBoard run is the logdir
  root (the run named '.'), only the Profile plugin "run" name is used, for
  backwards compatibility.

  Lookup order: the run map derived from the request (if the request yields
  one), then the `_run_to_profile_run_dir` cache, then a path computed from
  `self.logdir`.

  Args:
    run: the frontend run name, as described above, e.g. train/run1.
    request: Optional; werkzeug request used for grabbing session_path and
      run_path arguments.

  Returns:
    The resolved directory path, e.g. /logdir/train/plugins/profile/run1.

  Raises:
    ValueError: If the request yields a run map and the run is not in it.
    RuntimeError: If logdir is unset or the TensorBoard run directory does
      not exist.
  """
  run_map = self._run_map_from_request(request)
  if run_map is not None:
    # A request-scoped map is authoritative; never fall through to logdir.
    if run not in run_map:
      raise ValueError(f'Run {run} not found in run map: {run_map}')
    return run_map[run]

  known_dirs = self._run_to_profile_run_dir
  if run in known_dirs:
    return known_dirs[run]

  if not self.logdir:
    raise RuntimeError(
        'No matching run directory for run %s. Logdir is empty.' % run
    )

  tb_run_name, profile_run_name = os.path.split(run.rstrip(os.sep))
  # An empty head means the run lives directly under the logdir root.
  tb_run_directory = _tb_run_directory(self.logdir, tb_run_name or '.')
  if not epath.Path(tb_run_directory).is_dir():
    raise RuntimeError('No matching run directory for run %s' % run)
  return os.path.join(
      plugin_asset_util.PluginDirectory(tb_run_directory, PLUGIN_NAME),
      profile_run_name,
  )
887
+
888
def runs_imp(self, request: Optional[wrappers.Request] = None) -> list[str]:
  """Returns all runs for the profile plugin, sorted in descending order.

  Runs come from the request's run map when the request yields one, and from
  `generate_runs` otherwise.

  Args:
    request: Optional; werkzeug request used for grabbing ctx and experiment
      id for other host implementations. In this implementation it is only
      forwarded to `_run_map_from_request`.
  """
  request_run_map = self._run_map_from_request(request)
  candidates = (
      self.generate_runs()
      if request_run_map is None
      else request_run_map.keys()
  )
  return sorted(candidates, reverse=True)
901
+
902
# pytype: disable=wrong-arg-types
@wrappers.Request.application
def run_tools_route(self, request: wrappers.Request) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  """Responds with the JSON list of tools for the run in the `run` argument."""
  requested_run = request.args.get('run')
  return respond(
      self.run_tools_imp(requested_run, request), 'application/json'
  )
909
+
910
def run_tools_imp(
    self, run, request: Optional[wrappers.Request] = None
) -> list[str]:
  """Returns the list of tools available for a single run.

  Args:
    run: the frontend run name, i.e. an item of the list returned by
      runs_imp.
    request: Optional; werkzeug request used for grabbing ctx and experiment
      id for other host implementations. Here it is forwarded to `_run_dir`.
  """
  resolved_dir = self._run_dir(run, request)
  tools = self.generate_tools_of_run(run, resolved_dir)
  return list(tools)
922
+
923
def _run_host_impl(
    self, run: str, run_dir: str, tool: str
) -> List[HostMetadata]:
  """Lists the hosts that have xplane data in `run_dir` for `tool`.

  Args:
    run: The frontend run name; used only in the warning message.
    run_dir: The directory to search for `*.xplane.pb` files.
    tool: The requested tool, forwarded to `filenames_to_hosts`.

  Returns:
    A list of `{'hostname': <host>}` dicts, one per host returned by
    `filenames_to_hosts`. When `run_dir` is empty the list is empty. When the
    directory cannot be read, a warning is logged and an empty filename list
    is passed to `filenames_to_hosts`.
  """
  if not run_dir:
    logger.warning('Cannot find asset directory for: %s', run)
    return []
  tool_pattern = '*.xplane.pb'
  try:
    # glob() may produce its results lazily, so it is consumed inside the try
    # block; otherwise an OSError raised while iterating would escape the
    # handler below.
    filenames = [
        os.path.basename(f) for f in epath.Path(run_dir).glob(tool_pattern)
    ]
  except OSError as e:
    logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e)
    filenames = []

  return [{'hostname': host} for host in filenames_to_hosts(filenames, tool)]
939
+
940
def host_impl(
    self, run: str, tool: str, request: Optional[wrappers.Request] = None
) -> List[HostMetadata]:
  """Returns available hosts and their metadata for a run and tool.

  The run is first resolved to its directory with `_run_dir`; the host list
  is then built by `_run_host_impl` from the files in that directory.

  In the plugin log directory, each directory contains profile data for a
  single run (identified by the directory name), and files in the run
  directory contain data for different tools and hosts. The file that
  contains profile for a specific tool "x" will have extension TOOLS["x"].

  Example:
    log/
      run1/
        plugins/
          profile/
            host1.trace
            host2.trace
            module1.hlo_proto.pb
            module2.hlo_proto.pb
      run2/
        plugins/
          profile/
            host1.trace
            host2.trace

  Args:
    run: the frontend run name, e.g., 'run1' or 'run2' for the example above.
    tool: the requested tool, e.g., 'trace_viewer' for the example above.
    request: Optional; werkzeug request used for grabbing ctx and experiment
      id for other host implementations. Here it is forwarded to `_run_dir`.

  Returns:
    A list of host metadata dicts, e.g.:
      host_impl(run1, trace_viewer) -->
          [{"hostname": "host1"}, {"hostname": "host2"}]
      host_impl(run1, memory_viewer) -->
          [{"hostname": "module1"}, {"hostname": "module2"}]
  """
  return self._run_host_impl(run, self._run_dir(run, request), tool)
981
+
982
# pytype: disable=wrong-arg-types
@wrappers.Request.application
def hosts_route(self, request: wrappers.Request) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  """Responds with the JSON host list for the `run` and `tag` arguments."""
  query = request.args
  # The tool name travels in the `tag` query argument.
  host_list = self.host_impl(query.get('run'), query.get('tag'), request)
  return respond(host_list, 'application/json')
990
+
991
# pytype: disable=wrong-arg-types
@wrappers.Request.application
def hlo_module_list_route(
    self, request: wrappers.Request
) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  # The enable directive above matches the other routes in this class; without
  # it the suppression stays active for the definitions that follow.
  """Responds with the comma-separated HLO module names as plain text."""
  module_names_str = self.hlo_module_list_impl(request)
  return respond(module_names_str, 'text/plain')
998
+
999
def _get_valid_hosts(
    self, run_dir: str, run: str, tool: str, hosts_param: str, host: str
) -> tuple[List[str], List[epath.Path]]:
  """Retrieves and validates the hosts and asset paths for a run and tool.

  Selection rules, evaluated in order:
    1. `hosts_param` set and the tool supports multi-host processing: every
       listed host must have an xplane file.
    2. `host == ALL_HOSTS`: every discovered host is selected.
    3. `host` set and found: that single host is selected.
    4. `host` set but not found: a warning is logged and, subject to the
       check noted below, FileNotFoundError is raised.
    5. Neither `host` nor `hosts_param` set and exactly one xplane file
       exists: that file's host is selected.

  Args:
    run_dir: The run directory.
    run: The frontend run name.
    tool: The requested tool.
    hosts_param: Comma-separated list of selected hosts.
    host: The single host parameter.

  Returns:
    A tuple containing (selected_hosts, asset_paths). Both may be empty.

  Raises:
    FileNotFoundError: If a required xplane file for the specified host(s)
      is not found, or if no host was given for a tool that needs one.
    IOError: If there is an error reading asset directories.
  """
  asset_paths = []
  selected_hosts = []
  all_xplane_files = {}  # Map host to path

  # Find all available xplane files for the run and map them by host.
  file_pattern = make_filename('*', 'xplane')
  try:
    path = epath.Path(run_dir)
    for xplane_path in path.glob(file_pattern):
      host_name, _ = _parse_filename(xplane_path.name)
      if host_name:
        all_xplane_files[host_name] = xplane_path
  except OSError as e:
    logger.warning('Cannot read asset directory: %s, OpError %r', run_dir, e)
    raise IOError(
        'Cannot read asset directory: %s, OpError %r' % (run_dir, e)
    ) from e

  if hosts_param and self._does_tool_support_multi_hosts_processing(tool):
    selected_hosts = hosts_param.split(',')
    for selected_host in selected_hosts:
      if selected_host not in all_xplane_files:
        raise FileNotFoundError(
            'No xplane file found for host: %s in run: %s'
            % (selected_host, run)
        )
      asset_paths.append(all_xplane_files[selected_host])
    # The format string had a %s placeholder but no argument, so the literal
    # "%s" was logged; pass the value it was meant to show.
    logger.info('Inside trace_viewer@, asset_paths: %s', asset_paths)
  elif host == ALL_HOSTS:
    asset_paths = list(all_xplane_files.values())
    selected_hosts = list(all_xplane_files.keys())
  elif host and host in all_xplane_files:
    selected_hosts = [host]
    asset_paths = [all_xplane_files[host]]
  elif host:
    logger.warning('No xplane file found for host: %s in run: %s', host, run)
    # NOTE(review): this tests the *host* name against a constant whose name
    # suggests it holds tool names, while the check near the end of this
    # function tests `tool`. Possibly `tool` was intended here - confirm
    # before changing; behavior is left as-is.
    if host not in XPLANE_TOOLS_ALL_HOSTS_ONLY:
      raise FileNotFoundError(
          'No xplane file found for host: %s in run: %s' % (host, run)
      )
  elif not host and not hosts_param and len(all_xplane_files) == 1:
    # Unambiguous: only one host has data, so select it implicitly.
    selected_hosts = list(all_xplane_files.keys())
    asset_paths = list(all_xplane_files.values())

  if not asset_paths:
    logger.warning(
        'No matching asset paths found for run %s, tool %s, host(s) %s / %s',
        run,
        tool,
        hosts_param,
        host,
    )
    if not host and tool not in XPLANE_TOOLS_ALL_HOSTS_ONLY:
      raise FileNotFoundError(
          'Host must be specified for tool %s in run %s' % (tool, run)
      )

  return selected_hosts, asset_paths
1078
+
1079
def data_impl(
    self, request: wrappers.Request
) -> tuple[Optional[str], str, Optional[str]]:
  """Retrieves and processes the tool data for a run and a host.

  Args:
    request: XMLHttpRequest. Query arguments read here: run, tag (the tool
      name), hosts, host, module_name, tqx, use_saved_result, full_dma,
      group_by, view_memory_allocation_timeline, memory_space, and, for
      'trace_viewer@', resolution, start_time_ms, end_time_ms, event_name,
      duration_ms, unique_id and search_prefix.

  Returns:
    A tuple (data, content_type, content_encoding). `data` is the payload to
    serve to the frontend tool, or None if the tool is unknown, does not use
    xplane, or no asset paths match the requested host(s). In this
    implementation `content_encoding` is always None.

  Raises:
    FileNotFoundError: If a required xplane file for the specified host(s)
      is not found.
    IOError: If there is an error reading asset directories.
    AttributeError: If there is an error during xplane to tool data conversion
    ValueError: If xplane conversion fails due to invalid data.
  """
  import re  # Local import: only needed for version parsing below.

  def _version_key(text):
    """Turns '2.22.3a1' into a sortable key; returns None if unparseable."""
    match = re.match(r'(\d+(?:\.\d+)*)(.*)', text.strip())
    if not match:
      return None
    release = tuple(int(part) for part in match.group(1).split('.'))
    suffix = match.group(2)
    # A version with a suffix (e.g. 'a20251208') sorts before the same
    # release without one; this is an approximation of PEP 440 ordering.
    return (release, suffix == '', suffix)

  run = request.args.get('run')
  tool = request.args.get('tag')
  hosts_param = request.args.get('hosts')
  host = request.args.get('host')
  module_name = request.args.get('module_name')
  tqx = request.args.get('tqx')
  use_saved_result = _get_bool_arg(request.args, 'use_saved_result', True)
  full_dma = _get_bool_arg(request.args, 'full_dma', False)
  run_dir = self._run_dir(run, request)

  # Check if the cache file exists and if the cache file version is less
  # than the current plugin version, clear the cache.
  cache_version_path = epath.Path(os.path.join(run_dir, CACHE_VERSION_FILE))
  try:
    if cache_version_path.exists():
      with cache_version_path.open('r') as f:
        cached_key = _version_key(f.read())
      current_key = _version_key(version.__version__)
      # Compare numerically: a plain string comparison would rank '2.9.0'
      # above '2.22.3'. An unparseable version counts as stale.
      if (
          cached_key is None
          or current_key is None
          or cached_key < current_key
      ):
        use_saved_result = False
    else:
      use_saved_result = False
  except OSError as e:
    logger.warning('Cannot read cache version file: %r', e)
    use_saved_result = False

  graph_viewer_options = self._get_graph_viewer_options(request)
  # Host param is used by HLO tools to identify the module.
  params = {
      'graph_viewer_options': graph_viewer_options,
      'tqx': tqx,
      'host': host,
      'module_name': module_name,
      'use_saved_result': use_saved_result,
  }
  if request.args.get('group_by'):
    params['group_by'] = request.args.get('group_by')
  content_type = 'application/json'

  if tool not in TOOLS and not use_xplane(tool):
    return None, content_type, None
  if tool == 'memory_viewer' and request.args.get(
      'view_memory_allocation_timeline'
  ):
    params['view_memory_allocation_timeline'] = True

  params['memory_space'] = request.args.get('memory_space', '0')

  if tool == 'trace_viewer@':
    options = {
        'resolution': request.args.get('resolution', 8000),
        'full_dma': full_dma,
    }
    # Optional filters are forwarded only when present in the request.
    for option_name in (
        'start_time_ms',
        'end_time_ms',
        'event_name',
        'duration_ms',
        'unique_id',
        'search_prefix',
    ):
      option_value = request.args.get(option_name)
      if option_value is not None:
        options[option_name] = option_value
    params['trace_viewer_options'] = options

  if not use_xplane(tool):
    logger.info('%s does not use xplane', tool)
    return None, content_type, None

  selected_hosts, asset_paths = self._get_valid_hosts(
      run_dir, run, tool, hosts_param, host
  )
  if not asset_paths:
    return None, content_type, None

  params['hosts'] = selected_hosts
  try:
    data, content_type = convert.xspace_to_tool_data(
        asset_paths, tool, params)
  except AttributeError as e:
    logger.warning('Error generating analysis results due to %r', e)
    raise AttributeError(
        'Error generating analysis results due to %r' % e
    ) from e
  except (ValueError, FileNotFoundError) as e:
    logger.warning('XPlane convert to tool data failed as %r', e)
    raise

  # Write cache version file if use_saved_result is False.
  if not use_saved_result:
    try:
      with cache_version_path.open('w') as f:
        f.write(version.__version__)
    except OSError as e:
      # Best effort: failing to record the version must not fail the request.
      logger.warning('Cannot write cache version file: %r', e)

  return data, content_type, None
1202
+
1203
def hlo_module_list_impl(
    self, request: wrappers.Request
) -> str:
  """Returns a string of HLO module names concatenated by comma for the given run.

  Module names are parsed with `_parse_filename` from the `*.hlo_proto.pb`
  files in the run directory; files that yield no module name are skipped.

  Args:
    request: werkzeug request; the `run` argument names the run, and the
      request is forwarded to `_run_dir`.

  Returns:
    The comma-joined module names. Empty if the run directory is unknown,
    cannot be read (a warning is logged), or holds no HLO proto files.
  """
  run = request.args.get('run')
  run_dir = self._run_dir(run, request)
  if not run_dir:
    logger.warning('Cannot find asset directory for: %s', run)
    return ''
  tool_pattern = '*.hlo_proto.pb'
  try:
    # glob() may produce its results lazily, so it is consumed inside the try
    # block; otherwise an OSError raised while iterating would escape the
    # handler below.
    filenames = [
        os.path.basename(f) for f in epath.Path(run_dir).glob(tool_pattern)
    ]
  except OSError as e:
    logger.warning('Cannot read asset directory: %s, OpError %r', run_dir, e)
    filenames = []

  module_list = []
  for filename in filenames:
    module_name, _ = _parse_filename(filename)
    if module_name:
      module_list.append(module_name)
  return ','.join(module_list)
1227
+
1228
# pytype: disable=wrong-arg-types
@wrappers.Request.application
def data_route(self, request: wrappers.Request) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  """Serves tool data for the request.

  Responds 404 'No Data' when `data_impl` yields no data, and 500 with the
  exception text when it raises one of the handled errors.

  Args:
    request: XMLHTTPRequest.
  """
  # Every handled failure is reported the same way: plain text, HTTP 500.
  handled_errors = (
      TimeoutError,
      AttributeError,
      ValueError,
      FileNotFoundError,
      IOError,
  )
  try:
    data, content_type, content_encoding = self.data_impl(request)
    if data is not None:
      return respond(data, content_type, content_encoding=content_encoding)
    return respond('No Data', 'text/plain', code=404)
  except handled_errors as e:
    return respond(str(e), 'text/plain', code=500)
1250
+
1251
# pytype: disable=wrong-arg-types
@wrappers.Request.application
def capture_route(self, request: wrappers.Request) -> wrappers.Response:
  # pytype: enable=wrong-arg-types
  """HTTP entry point for profile capture; delegates to `capture_route_impl`."""
  response = self.capture_route_impl(request)
  return response
1256
+
1257
def capture_route_impl(self, request: wrappers.Request) -> wrappers.Response:
  """Runs the client trace for capturing profiling information.

  Args:
    request: werkzeug request. Query arguments read here: service_addr
      (required), duration (default 1000), is_tpu_name, worker_list,
      num_retry (default 0), host_tracer_level (default 2),
      device_tracer_level (default 1), python_tracer_level (default 0) and
      delay (default 0).

  Returns:
    A JSON response: `{'result': ...}` on success, or `{'error': ...}` with
    status 400 for a missing or malformed argument and status 500 for any
    other failure.
  """
  args = request.args
  service_addr = args.get('service_addr')
  if not service_addr:
    # Without this check the failure surfaced later as an opaque
    # "'NoneType' object has no attribute 'removeprefix'" error.
    return respond(
        {'error': 'service_addr is required to capture a profile.'},
        'application/json',
        code=400,
    )

  try:
    duration = int(args.get('duration', '1000'))
    num_tracing_attempts = int(args.get('num_retry', '0')) + 1
    options = {
        'host_tracer_level': int(args.get('host_tracer_level', '2')),
        'device_tracer_level': int(args.get('device_tracer_level', '1')),
        'python_tracer_level': int(args.get('python_tracer_level', '0')),
        'delay_ms': int(args.get('delay', '0')),
    }
  except ValueError as err:
    # A non-numeric query argument is a client error, not a server crash.
    return respond(
        {'error': 'Invalid numeric capture option: %s' % err},
        'application/json',
        code=400,
    )
  is_tpu_name = args.get('is_tpu_name') == 'true'
  worker_list = args.get('worker_list')

  if is_tpu_name:
    if not self._tf_profiler:
      return respond(
          {
              'error': (
                  'TensorFlow is not installed, but is required to use TPU'
                  ' names.'
              )
          },
          'application/json',
          code=500,
      )
    try:
      # Delegate to the helper class for all TF-related logic.
      service_addr, worker_list, master_ip = (
          self._tf_profiler.resolve_tpu_name(service_addr, worker_list or '')
      )
      self.master_tpu_unsecure_channel = master_ip
    except (RuntimeError, ValueError) as err:
      return respond({'error': str(err)}, 'application/json', code=500)

  if not self.logdir:
    return respond(
        {'error': 'logdir is not set, abort capturing.'},
        'application/json',
        code=500,
    )
  try:
    # The core trace call remains, now with cleanly resolved parameters.
    # NOTE(review): the meaning of the positional `True` is defined by the
    # native `trace` binding and is not visible here.
    _pywrap_profiler_plugin.trace(
        service_addr.removeprefix('grpc://'),
        str(self.logdir),
        worker_list,
        True,
        duration,
        num_tracing_attempts,
        options,
    )
    return respond(
        {'result': 'Capture profile successfully. Please refresh.'},
        'application/json',
    )
  except Exception as e:  # pylint: disable=broad-except
    return respond({'error': str(e)}, 'application/json', code=500)
1319
+
1320
def _get_graph_viewer_options(
    self, request: wrappers.Request
) -> dict[str, Any]:
  """Collects the graph viewer options from the request's query arguments.

  Args:
    request: werkzeug request. Arguments read: node_name, module_name,
      graph_width, show_metadata, merge_fusion, format and type.

  Returns:
    A dict with keys node_name, module_name, graph_width, show_metadata,
    merge_fusion, format and type. `graph_width` falls back to 3 when the
    argument is missing or is not a plain non-negative integer.
    `show_metadata` and `merge_fusion` are 1 when the argument equals 'true'
    and 0 otherwise.
  """
  graph_width_str = request.args.get('graph_width') or ''
  # isdecimal() rather than isdigit(): isdigit() also accepts characters such
  # as superscript digits ('²') that int() cannot parse, which would raise
  # ValueError instead of falling back to the default.
  graph_width = int(graph_width_str) if graph_width_str.isdecimal() else 3
  return {
      'node_name': request.args.get('node_name'),
      'module_name': request.args.get('module_name'),
      'graph_width': graph_width,
      'show_metadata': int(request.args.get('show_metadata') == 'true'),
      'merge_fusion': int(request.args.get('merge_fusion') == 'true'),
      'format': request.args.get('format'),
      'type': request.args.get('type')
  }
1338
+
1339
def generate_runs(self) -> Iterator[str]:
  """Generator for a list of runs.

  The "run name" here is a "frontend run name" - see _tb_run_directory() for
  the definition of a "frontend run name" and how it maps to a directory of
  profile data for a specific profile "run". The profile plugin concept of
  "run" is different from the normal TensorBoard run; each run in this case
  represents a single instance of profile data collection, more similar to a
  "step" of data in typical TensorBoard semantics. These runs reside in
  subdirectories of the plugins/profile directory within any regular
  TensorBoard run directory or within the session_dir root directory
  itself (even if it contains no tfevents file and would thus not be
  considered a normal TensorBoard run, for backwards compatibility).

  `generate_runs` will get all runs first, and get tools list from
  `generate_tools_of_run` for a single run due to expensive processing for
  xspace data to parse the tools.

  As a side effect, every run whose profile directory exists is recorded in
  `self._run_to_profile_run_dir` (frontend run name -> directory), which
  `_run_dir` consults.

  Example:
    logs/
      plugins/
        profile/
          run1/
            hostA.trace
      train/
        events.out.tfevents.foo
        plugins/
          profile/
            run1/
              hostA.trace
              hostB.trace
            run2/
              hostA.trace
      validation/
        events.out.tfevents.foo
        plugins/
          profile/
            run1/
              hostA.trace
      new_job/
        tensorboard/
          plugins/
            profile/
              run1/
                hostA.xplane.pb

  Yields:
    A sequence of strings that are "frontend run names".
    For the above example, this would be:
        "run1", "train/run1", "train/run2", "validation/run1",
        "new_job/tensorboard/run1"
    Each name is yielded at most once. Nothing is yielded when
    `self.logdir` is unset.
  """
  if not self.logdir:
    return

  # Ensure that we check the root logdir and all subdirectories.
  # Note that we check if logdir is a directory to handle the case where
  # it's actually a multipart directory spec, which this plugin does not
  # support.
  #
  # This change still enforces the requirement that the subdirectories must
  # end with a plugins/profile directory, as enforced by TensorBoard.
  logdir_path = epath.Path(self.logdir)
  # Drop any "<scheme>://" prefix; the glob branch below computes relative
  # paths against this scheme-less form.
  schemeless_logdir = str(logdir_path)
  if '://' in schemeless_logdir:
    schemeless_logdir = schemeless_logdir.split('://', 1)[1]
  # The logdir root itself ('.') is always considered a TensorBoard run.
  tb_runs = {'.'}

  if logdir_path.is_dir():
    try:
      fs = etils.epath.backend.fsspec_backend.fs(self.logdir)
      # Find every PLUGIN_NAME directory at any depth whose parent is named
      # TB_NAME; the directory two levels up is the TensorBoard run.
      for path_str in fs.glob(os.path.join(self.logdir, '**', PLUGIN_NAME)):
        path = epath.Path(path_str)
        if fs.isdir(path) and path.parent.name == TB_NAME:
          tb_run_dir = path.parent.parent
          tb_run = tb_run_dir.relative_to(schemeless_logdir)
          tb_runs.add(str(tb_run))
    except ValueError:
      # gcsfs not available, fall back to legacy path walk.
      for cur_dir, _, _ in logdir_path.walk():
        if (cur_dir.name == PLUGIN_NAME and cur_dir.parent.name == TB_NAME):
          tb_run_dir = cur_dir.parent.parent
          tb_run = tb_run_dir.relative_to(logdir_path)
          tb_runs.add(str(tb_run))
  tb_run_names_to_dirs = {
      run: _tb_run_directory(self.logdir, run) for run in tb_runs
  }
  plugin_assets = _plugin_assets(
      self.logdir, list(tb_run_names_to_dirs), PLUGIN_NAME
  )
  # Guards against yielding the same frontend run name twice.
  visited_runs = set()
  for tb_run_name, profile_runs in six.iteritems(plugin_assets):
    tb_run_dir = tb_run_names_to_dirs[tb_run_name]
    tb_plugin_dir = plugin_asset_util.PluginDirectory(tb_run_dir, PLUGIN_NAME)

    for profile_run in profile_runs:
      # Remove trailing separator; some filesystem implementations emit this.
      profile_run = profile_run.rstrip(os.sep)
      # Runs directly under the logdir root keep their bare name; all others
      # are prefixed with the TensorBoard run name.
      if tb_run_name == '.':
        frontend_run = profile_run
      else:
        frontend_run = str(epath.Path(tb_run_name) / profile_run)
      profile_run_dir = str(epath.Path(tb_plugin_dir) / profile_run)
      # Only entries that are directories count as profile runs.
      if epath.Path(profile_run_dir).is_dir():
        self._run_to_profile_run_dir[frontend_run] = profile_run_dir
        if frontend_run not in visited_runs:
          visited_runs.add(frontend_run)
          yield frontend_run
1445
+
1446
def generate_tools_of_run(self, run: str, run_dir: str) -> Iterator[str]:
  """Generate a list of tools given a certain run.

  Tools are served from `ToolsCache` when it has a valid entry. Otherwise
  they are derived from the filenames in the run directory via
  `_get_active_tools` and saved back to the cache.

  Args:
    run: The frontend run name; used only in the warning message.
    run_dir: The profile run directory to inspect.

  Yields:
    Tool names. Nothing is yielded when `run_dir` is empty, the directory
    cannot be listed (a warning is logged), or it contains no files.
  """
  if not run_dir:
    logger.warning('Cannot find asset directory for: %s', run)
    return
  profile_run_dir = epath.Path(run_dir)
  cache = ToolsCache(profile_run_dir)

  cached_tools = cache.load()
  if cached_tools is not None:
    yield from cached_tools
    return

  # Cache is invalid or doesn't exist, regenerate
  try:
    all_filenames = [f.name for f in profile_run_dir.iterdir()]
  except OSError as e:
    logger.warning(
        'Cannot read asset directory: %s, Error %r', profile_run_dir, e
    )
    # Bare return: a value returned from a generator is never seen by
    # callers that simply iterate over it.
    return

  if not all_filenames:
    # An empty directory is not cached, so it is re-examined on the next call.
    return

  tools = self._get_active_tools(all_filenames, str(profile_run_dir))
  cache.save(tools)
  yield from tools
1477
+
1478
def _get_active_tools(self, filenames, profile_run_dir=''):
  """Returns the tools available for the given profiler output files.

  Tools listed in the preferred display order come first, in that order; any
  other tools follow in alphabetical order.

  Args:
    filenames: List of strings that represent filenames
    profile_run_dir: The run directory of the profile.

  Returns:
    A list of strings representing the available tools
  """
  preferred_order = (
      'overview_page',
      'trace_viewer',
      'trace_viewer@',
      'graph_viewer',
      'op_profile',
      'hlo_op_profile',
      'input_pipeline_analyzer',
      'input_pipeline',
      'kernel_stats',
      'memory_profile',
      'memory_viewer',
      'roofline_model',
      'perf_counters',
      'pod_viewer',
      'framework_op_stats',
      'tensorflow_stats',  # Legacy name for framework_op_stats
      'hlo_op_stats',
      'hlo_stats',  # Legacy name for hlo_op_stats
      'inference_profile',
      'megascale_stats',
  )
  available = _get_tools(filenames, profile_run_dir)
  # streaming trace viewer always override normal trace viewer.
  # the trailing '@' is to inform tf-profile-dashboard.html and
  # tf-trace-viewer.html that stream trace viewer should be used.
  if 'trace_viewer@' in available:
    available.discard('trace_viewer')

  ordered = [name for name in preferred_order if name in available]
  unranked = available.difference(ordered)
  ordered.extend(sorted(unranked))
  return ordered