xprof-nightly 2.22.3a20251208__cp311-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xprof/__init__.py +22 -0
- xprof/convert/_pywrap_profiler_plugin.so +0 -0
- xprof/convert/csv_writer.py +87 -0
- xprof/convert/raw_to_tool_data.py +232 -0
- xprof/convert/trace_events_json.py +105 -0
- xprof/integration_tests/tf_mnist.py +100 -0
- xprof/integration_tests/tf_profiler_session.py +40 -0
- xprof/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py +183 -0
- xprof/profile_plugin.py +1521 -0
- xprof/profile_plugin_loader.py +82 -0
- xprof/protobuf/dcn_collective_info_pb2.py +44 -0
- xprof/protobuf/dcn_slack_analysis_pb2.py +42 -0
- xprof/protobuf/diagnostics_pb2.py +36 -0
- xprof/protobuf/event_time_fraction_analyzer_pb2.py +42 -0
- xprof/protobuf/hardware_types_pb2.py +40 -0
- xprof/protobuf/hlo_stats_pb2.py +39 -0
- xprof/protobuf/inference_stats_pb2.py +86 -0
- xprof/protobuf/input_pipeline_pb2.py +52 -0
- xprof/protobuf/kernel_stats_pb2.py +38 -0
- xprof/protobuf/memory_profile_pb2.py +54 -0
- xprof/protobuf/memory_viewer_preprocess_pb2.py +49 -0
- xprof/protobuf/op_metrics_pb2.py +65 -0
- xprof/protobuf/op_profile_pb2.py +49 -0
- xprof/protobuf/op_stats_pb2.py +71 -0
- xprof/protobuf/overview_page_pb2.py +64 -0
- xprof/protobuf/pod_stats_pb2.py +45 -0
- xprof/protobuf/pod_viewer_pb2.py +61 -0
- xprof/protobuf/power_metrics_pb2.py +38 -0
- xprof/protobuf/roofline_model_pb2.py +42 -0
- xprof/protobuf/smart_suggestion_pb2.py +38 -0
- xprof/protobuf/source_info_pb2.py +36 -0
- xprof/protobuf/source_stats_pb2.py +48 -0
- xprof/protobuf/steps_db_pb2.py +76 -0
- xprof/protobuf/task_pb2.py +37 -0
- xprof/protobuf/tf_data_stats_pb2.py +72 -0
- xprof/protobuf/tf_function_pb2.py +52 -0
- xprof/protobuf/tf_stats_pb2.py +40 -0
- xprof/protobuf/tfstreamz_pb2.py +40 -0
- xprof/protobuf/topology_pb2.py +50 -0
- xprof/protobuf/tpu_input_pipeline_pb2.py +43 -0
- xprof/protobuf/trace_events_old_pb2.py +54 -0
- xprof/protobuf/trace_events_pb2.py +64 -0
- xprof/protobuf/trace_events_raw_pb2.py +45 -0
- xprof/protobuf/trace_filter_config_pb2.py +40 -0
- xprof/server.py +319 -0
- xprof/standalone/base_plugin.py +52 -0
- xprof/standalone/context.py +22 -0
- xprof/standalone/data_provider.py +32 -0
- xprof/standalone/plugin_asset_util.py +131 -0
- xprof/standalone/plugin_event_multiplexer.py +185 -0
- xprof/standalone/tensorboard_shim.py +31 -0
- xprof/static/bundle.js +130500 -0
- xprof/static/index.html +64 -0
- xprof/static/index.js +3 -0
- xprof/static/materialicons.woff2 +0 -0
- xprof/static/styles.css +1 -0
- xprof/static/trace_viewer_index.html +3929 -0
- xprof/static/trace_viewer_index.js +15906 -0
- xprof/static/zone.js +3558 -0
- xprof/version.py +17 -0
- xprof_nightly-2.22.3a20251208.dist-info/METADATA +301 -0
- xprof_nightly-2.22.3a20251208.dist-info/RECORD +65 -0
- xprof_nightly-2.22.3a20251208.dist-info/WHEEL +5 -0
- xprof_nightly-2.22.3a20251208.dist-info/entry_points.txt +5 -0
- xprof_nightly-2.22.3a20251208.dist-info/top_level.txt +1 -0
xprof/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright 2025 The XProf Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Entry point for the TensorBoard plugin package for XProf.
|
|
16
|
+
|
|
17
|
+
Public submodules:
|
|
18
|
+
profile_plugin: The TensorBoard plugin integration.
|
|
19
|
+
profile_plugin_loader: TensorBoard's entrypoint for the plugin.
|
|
20
|
+
server: Standalone server entrypoint.
|
|
21
|
+
version: The version of the plugin.
|
|
22
|
+
"""
|
|
Binary file
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""Converts data between different formats."""
|
|
17
|
+
|
|
18
|
+
import csv
|
|
19
|
+
import io
|
|
20
|
+
import json
|
|
21
|
+
from typing import List, Optional
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _format_csv_cell(value):
  """Renders one decoded JSON value as the text of a CSV cell."""
  if value is None:
    return ""
  if isinstance(value, bool):
    # Checked before str(): bool is an int subclass and would become "True".
    return "true" if value else "false"
  if isinstance(value, (dict, list)):
    # Keep nested structures as JSON text instead of Python repr.
    return json.dumps(value, ensure_ascii=False)
  return str(value)


def json_to_csv(
    json_string: str,
    columns_order: Optional[List[str]] = None,
    separator: str = ",",
) -> str:
  """Converts a JSON object string to a CSV string.

  The JSON document must be a single object. Its keys form the header row and
  its values form the single data row.

  Args:
    json_string: The JSON string to convert. It must decode to an object.
    columns_order: Optional. The columns to emit, in output order. Every entry
      must be a key of the JSON object; keys that are not listed are left out.
      Defaults to all keys, in the object's own order.
    separator: Optional. The delimiter written between values.

  Returns:
    A CSV string with a header row and one data row, each newline-terminated,
    or "" when there are no columns. Cell values are rendered as follows:
    null becomes an empty cell, booleans become "true"/"false", nested arrays
    and objects become compact JSON text, and anything else goes through str().
    Cells containing the separator, a quote or a newline are quoted by the csv
    module.
    Example result for '{"a": 1, "b": "z", "c": true}':
      a,b,c
      1,z,true

  Raises:
    ValueError: The json_string is not valid JSON, is not a single object, or
      columns_order names a key that is not in the object.
  """
  try:
    data = json.loads(json_string)
  except json.JSONDecodeError as e:
    raise ValueError(f"Invalid JSON string: {e}") from e

  if not isinstance(data, dict):
    raise ValueError("JSON data must be a single object")

  if columns_order is None:
    columns_order = list(data)

  # Unknown column IDs are rejected; a subset of the keys is accepted.
  if not all(col in data for col in columns_order):
    raise ValueError("columns_order must be a list of all column IDs")

  if not columns_order:
    return ""

  csv_buffer = io.StringIO(newline="")
  writer = csv.writer(csv_buffer, delimiter=separator, lineterminator="\n")
  writer.writerow(columns_order)
  writer.writerow([_format_csv_cell(data[col]) for col in columns_order])
  return csv_buffer.getvalue()
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""For conversion of raw files to tool data.
|
|
16
|
+
|
|
17
|
+
Usage:
|
|
18
|
+
data, content_type = xspace_to_tool_data(xspace_paths, tool, params)
|
|
19
|
+
tool_names = xspace_to_tool_names(xspace_paths)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import absolute_import
|
|
23
|
+
from __future__ import division
|
|
24
|
+
from __future__ import print_function
|
|
25
|
+
|
|
26
|
+
import logging
|
|
27
|
+
|
|
28
|
+
from xprof.convert import csv_writer
|
|
29
|
+
from xprof.convert import trace_events_json
|
|
30
|
+
from xprof.protobuf import trace_events_old_pb2
|
|
31
|
+
from xprof.convert import _pywrap_profiler_plugin
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Module-level logger. It uses the fixed name 'tensorboard' rather than
# __name__, presumably so these messages are routed with TensorBoard's own
# logging configuration — confirm.
logger = logging.getLogger('tensorboard')
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def process_raw_trace(raw_trace):
  """Decodes a serialized legacy Trace proto into catapult JSON text.

  Args:
    raw_trace: Serialized `trace_events_old_pb2.Trace` message bytes.

  Returns:
    The complete JSON document produced by `TraceEventsJsonStream`, as one str.
  """
  parsed = trace_events_old_pb2.Trace()
  parsed.ParseFromString(raw_trace)
  chunks = list(trace_events_json.TraceEventsJsonStream(parsed))
  return ''.join(chunks)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
                                          params):
  """Helper function for getting an XSpace tool from byte strings.

  This is the in-memory counterpart of `xspace_to_tool_data`. Instead of XSpace
  file paths, it passes the already-read file contents, together with their
  file names, to the native byte-string converter.

  Args:
    xspace_byte_list: A list of byte strings read from XSpace proto files.
    filenames: Names of the read files.
    tool: A string of tool name.
    params: user input parameters.

  Returns:
    The (data, content_type) tuple returned by `xspace_to_tool_data`.
  """

  def xspace_wrapper_func(xspace_arg, tool_arg, options=None):
    # A None default avoids a shared mutable dict default. Callers in this
    # module always pass their own options positionally.
    if options is None:
      options = {}
    return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string(
        xspace_arg, filenames, tool_arg, options)

  return xspace_to_tool_data(xspace_byte_list, tool, params,
                             xspace_wrapper_func)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def xspace_to_tool_names(xspace_paths):
  """Converts XSpace to all the available tool names.

  Args:
    xspace_paths: A list of XSpace paths.

  Returns:
    A list of tool names, or an empty list if the native conversion reports
    failure. Empty entries (for example from an empty result or a trailing
    comma) are dropped.
  """
  raw_data, success = _pywrap_profiler_plugin.xspace_to_tools_data(
      xspace_paths, 'tool_names')
  if not success:
    return []
  # The names arrive as one comma-separated byte string. ''.split(',') yields
  # [''], so empty pieces are filtered out.
  return [name for name in raw_data.decode().split(',') if name]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _maybe_to_csv(json_data, tqx):
  """Returns json_data, converted to CSV when tqx requests 'out:csv'."""
  if tqx == 'out:csv':
    return csv_writer.json_to_csv(json_data)
  return json_data


def xspace_to_tool_data(
    xspace_paths,
    tool,
    params,
    xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data):
  """Converts XSpace to tool data string.

  Args:
    xspace_paths: A list of XSpace paths. trace_viewer and memory_profile
      `assert` that exactly one path is given.
    tool: A string of tool name. A trailing '^' (old tool-name format) is
      stripped and a warning is logged.
    params: user input parameters. Nested option dicts
      ('trace_viewer_options', 'graph_viewer_options') are copied before
      options are added, so the caller's dicts are not modified.
    xspace_wrapper_func: A callable taking (xspace_paths, tool, options) and
      returning a (raw_data, success) tuple. If success is false, raw_data
      contains the error message.

  Returns:
    A (data, content_type) tuple. data is None when the tool is unknown or the
    conversion failed (graph_viewer raises instead).

  Raises:
    ValueError: graph_viewer conversion failed; the message is the error text
      returned by xspace_wrapper_func.
  """
  # endswith() (rather than tool[-1]) also copes with an empty tool name.
  if tool.endswith('^'):
    old_tool = tool
    tool = tool[:-1]  # Remove the trailing '^'
    logger.warning(
        'Received old tool format: %s; mapped to new format: %s', old_tool, tool
    )
  data = None
  content_type = 'application/json'
  # tqx: gViz output format
  tqx = params.get('tqx', '')
  # Default options for tools that do not build their own.
  options = {'use_saved_result': params.get('use_saved_result', True)}
  if tool == 'trace_viewer':
    # Trace viewer handles one host at a time.
    assert len(xspace_paths) == 1
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = process_raw_trace(raw_data)
  elif tool == 'trace_viewer@':
    # Copy so the caller's nested dict is not mutated below.
    options = dict(params.get('trace_viewer_options', {}))
    options['use_saved_result'] = params.get('use_saved_result', True)
    options['hosts'] = params.get('hosts', [])
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool in ('overview_page', 'input_pipeline_analyzer', 'pod_viewer',
                'hlo_stats', 'roofline_model', 'inference_profile'):
    # These tools return the converter output unchanged.
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool == 'framework_op_stats':
    json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if not success:
      # Try legacy tool name: Handle backward compatibility with lower TF
      # version.
      # TODO(b/419013992): Remove this tool completely as it has been deprecated
      legacy_tool = 'tensorflow_stats'
      json_data, success = xspace_wrapper_func(
          xspace_paths, legacy_tool, options
      )
    if success:
      data = _maybe_to_csv(json_data, tqx)
  elif tool == 'kernel_stats':
    json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = _maybe_to_csv(json_data, tqx)
  elif tool == 'memory_profile':
    # Memory profile handles one host at a time.
    assert len(xspace_paths) == 1
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool in ('op_profile', 'hlo_op_profile'):
    options['group_by'] = params.get('group_by', 'program')
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool == 'graph_viewer':
    download_hlo_types = ['pb', 'pbtxt', 'json', 'short_txt', 'long_txt']
    graph_html_type = 'graph'
    # Copy so the caller's nested dict is not mutated below.
    options = dict(params.get('graph_viewer_options', {}))
    options['use_saved_result'] = params.get('use_saved_result', True)
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
      content_type = 'text/plain'
      data_type = options.get('type', '')
      if data_type in download_hlo_types:
        content_type = 'application/octet-stream'
      if data_type == graph_html_type:
        content_type = 'text/html'
    else:
      # TODO(tf-profiler) Handle errors for other tools as well,
      # to pass along the error message to client
      if isinstance(raw_data, bytes):
        raw_data = raw_data.decode('utf-8')
      raise ValueError(raw_data)
  elif tool == 'memory_viewer':
    view_memory_allocation_timeline = params.get(
        'view_memory_allocation_timeline', False
    )
    options = {
        'module_name': params.get('module_name'),
        'view_memory_allocation_timeline': view_memory_allocation_timeline,
        'memory_space': params.get('memory_space', ''),
    }
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
      if view_memory_allocation_timeline:
        content_type = 'text/html'
  elif tool == 'megascale_stats':
    options = {'host_name': params.get('host')}
    json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = json_data
  else:
    logger.warning('%s is not a known xplane tool', tool)
  return data, content_type
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
"""Converts trace events to JSON format consumed by catapult trace viewer."""
|
|
17
|
+
|
|
18
|
+
from __future__ import absolute_import
|
|
19
|
+
from __future__ import division
|
|
20
|
+
from __future__ import print_function
|
|
21
|
+
|
|
22
|
+
import json
|
|
23
|
+
import six
|
|
24
|
+
|
|
25
|
+
# Values for type (ph) and s (scope) parameters in catapult trace format.
|
|
26
|
+
_TYPE_METADATA = 'M'
|
|
27
|
+
_TYPE_COMPLETE = 'X'
|
|
28
|
+
_TYPE_INSTANT = 'i'
|
|
29
|
+
_SCOPE_THREAD = 't'
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TraceEventsJsonStream:
  """A streaming trace file in the format expected by catapult trace viewer.

  Iterating over this yields a sequence of string chunks, so it is suitable for
  returning in a werkzeug Response.
  """

  def __init__(self, proto):
    """Create an iterable JSON stream over the supplied Trace.

    Args:
      proto: a tensorboard.profile.Trace protobuf
    """
    self._proto = proto

  def _events(self):
    """Iterator over all catapult trace events, as python values.

    For every device (ordered by id), this yields a process_name metadata
    event (only if the device has a name) and a process_sort_index event.
    The same pair (thread_name / thread_sort_index) follows for each of the
    device's resources, ordered by id. Then one event is yielded per entry of
    trace_events, in proto order.
    """
    for did, device in sorted(self._proto.devices.items()):
      if device.name:
        yield dict(
            ph=_TYPE_METADATA,
            pid=did,
            name='process_name',
            args=dict(name=device.name))
      yield dict(
          ph=_TYPE_METADATA,
          pid=did,
          name='process_sort_index',
          args=dict(sort_index=did))
      for rid, resource in sorted(device.resources.items()):
        if resource.name:
          yield dict(
              ph=_TYPE_METADATA,
              pid=did,
              tid=rid,
              name='thread_name',
              args=dict(name=resource.name))
        yield dict(
            ph=_TYPE_METADATA,
            pid=did,
            tid=rid,
            name='thread_sort_index',
            args=dict(sort_index=rid))
    for event in self._proto.trace_events:
      yield self._event(event)

  def _event(self, event):
    """Converts a TraceEvent proto into a catapult trace event python value.

    The *_ps picosecond fields are divided by 1e6, i.e. converted to
    microseconds. Events with a non-zero duration become complete ('X')
    events; zero-duration events become thread-scoped instant ('i') events.
    An 'args' entry is added only when the event carries args.
    """
    result = dict(
        pid=event.device_id,
        tid=event.resource_id,
        name=event.name,
        ts=event.timestamp_ps / 1000000.0)
    if event.duration_ps:
      result['ph'] = _TYPE_COMPLETE
      result['dur'] = event.duration_ps / 1000000.0
    else:
      result['ph'] = _TYPE_INSTANT
      result['s'] = _SCOPE_THREAD
    if event.args:
      result['args'] = dict(event.args)
    return result

  def __iter__(self):
    """Returns an iterator of string chunks of a complete JSON document."""
    yield '{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},\n'
    yield '"traceEvents":[\n'
    for event in self._events():
      yield json.dumps(event, sort_keys=True)
      yield ',\n'
    # Add one fake event to avoid dealing with no-trailing-comma rule.
    yield '{}]}\n'
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Tensorflow based mnist model and input datasets helper."""
|
|
16
|
+
|
|
17
|
+
# Number of MNIST digit classes. It sets the width of the one-hot labels and
# of the model's final softmax layer.
_NUM_CLASSES = 10
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_input_datasets(tf, use_bfloat16=False):
  """Downloads MNIST and wraps it in batched, endlessly repeating pipelines.

  Pixel values are converted to float32 and scaled into [0, 1]. Inside each
  pipeline, images are then cast to bfloat16 or float32. Labels are one-hot
  encoded over _NUM_CLASSES classes. Each pipeline repeats indefinitely and
  batches by 64, dropping the remainder.

  Args:
    tf: version specific tensorflow (either tf1 or tf2).
    use_bfloat16: Boolean to determine if input should be cast to bfloat16.

  Returns:
    A (train_ds, eval_ds, test_ds, input_shape) tuple:
      - train_ds yields (image, label) batches from the MNIST train split.
      - eval_ds yields (image, label) batches from the test split.
      - test_ds yields image-only batches from the test split.
      - input_shape is the per-image shape, with the channel axis placed
        according to tf.keras.backend.image_data_format().
  """
  height = width = 28  # MNIST image dimensions.
  feature_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
  batch = 64

  (train_images, train_labels), (test_images, test_labels) = (
      tf.keras.datasets.mnist.load_data())

  if tf.keras.backend.image_data_format() == 'channels_first':
    input_shape = (1, height, width)
  else:
    input_shape = (height, width, 1)
  train_images = train_images.reshape((train_images.shape[0],) + input_shape)
  test_images = test_images.reshape((test_images.shape[0],) + input_shape)

  # Scale pixel values from [0, 255] into [0, 1].
  train_images = train_images.astype('float32')
  test_images = test_images.astype('float32')
  train_images /= 255
  test_images /= 255

  # One-hot encode the integer class labels.
  train_labels = tf.keras.utils.to_categorical(train_labels, _NUM_CLASSES)
  test_labels = tf.keras.utils.to_categorical(test_labels, _NUM_CLASSES)

  train_ds = (
      tf.data.Dataset.from_tensor_slices((train_images, train_labels))
      .repeat()
      .map(lambda image, label: (tf.cast(image, feature_dtype), label))
      .batch(batch, drop_remainder=True))

  eval_ds = (
      tf.data.Dataset.from_tensor_slices((test_images, test_labels))
      .repeat()
      .map(lambda image, label: (tf.cast(image, feature_dtype), label))
      .batch(batch, drop_remainder=True))

  test_ds = (
      tf.data.Dataset.from_tensor_slices(test_images)
      .repeat()
      .map(lambda image: tf.cast(image, feature_dtype))
      .batch(batch, drop_remainder=True))

  return train_ds, eval_ds, test_ds, input_shape
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def get_model(tf, input_shape):
  """Builds a Sequential CNN model to recognize MNIST digits.

  The layer stack is:
    - two 3x3 ReLU convolutions (32, then 64 filters);
    - 2x2 max pooling;
    - dropout of 0.25;
    - flatten;
    - a 128-unit ReLU dense layer;
    - dropout of 0.5;
    - a softmax output over _NUM_CLASSES classes.

  Args:
    tf: version specific tensorflow (either tf1 or tf2).
    input_shape: Shape of the input depending on the `image_data_format`.

  Returns:
    The Keras Sequential model (not compiled here).
  """
  layers = tf.keras.layers
  model = tf.keras.models.Sequential()
  for layer in (
      layers.Conv2D(
          32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
      layers.Conv2D(64, (3, 3), activation='relu'),
      layers.MaxPooling2D(pool_size=(2, 2)),
      layers.Dropout(0.25),
      layers.Flatten(),
      layers.Dense(128, activation='relu'),
      layers.Dropout(0.5),
      layers.Dense(_NUM_CLASSES, activation='softmax'),
  ):
    model.add(layer)
  return model
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Programmatic functionality for tensorflow profile session."""
|
|
16
|
+
|
|
17
|
+
import types
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TensorflowProfilerSession:
  """Context manager for a TensorFlow programmatic profile session.

  Entering the context starts the profiler with `log_dir`. Exiting stops it.
  """

  def __init__(
      self, tf: types.ModuleType, log_dir: str, python_tracer_level: int = 1
  ):
    """Initializes the session without starting it.

    Args:
      tf: The (version dependent) TensorFlow module. Its
        `profiler.experimental` API is used.
      log_dir: Log directory passed to the profiler's start().
      python_tracer_level: Python tracer level forwarded to ProfilerOptions.
    """
    self.profiler = tf.profiler.experimental
    self.log_dir = log_dir
    self.python_tracer_level = python_tracer_level

  def __enter__(self):
    """Starts the profile session and returns this session object."""
    options = self.profiler.ProfilerOptions(
        host_tracer_level=2, python_tracer_level=self.python_tracer_level
    )
    self.profiler.start(self.log_dir, options=options)
    # Return self so `with TensorflowProfilerSession(...) as s:` binds s.
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Stops the profile session. Exceptions from the body propagate."""
    self.profiler.stop()
|