xprof-nightly 2.22.3a20251208__cp311-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. xprof/__init__.py +22 -0
  2. xprof/convert/_pywrap_profiler_plugin.so +0 -0
  3. xprof/convert/csv_writer.py +87 -0
  4. xprof/convert/raw_to_tool_data.py +232 -0
  5. xprof/convert/trace_events_json.py +105 -0
  6. xprof/integration_tests/tf_mnist.py +100 -0
  7. xprof/integration_tests/tf_profiler_session.py +40 -0
  8. xprof/integration_tests/tpu/tensorflow/tpu_tf2_keras_test.py +183 -0
  9. xprof/profile_plugin.py +1521 -0
  10. xprof/profile_plugin_loader.py +82 -0
  11. xprof/protobuf/dcn_collective_info_pb2.py +44 -0
  12. xprof/protobuf/dcn_slack_analysis_pb2.py +42 -0
  13. xprof/protobuf/diagnostics_pb2.py +36 -0
  14. xprof/protobuf/event_time_fraction_analyzer_pb2.py +42 -0
  15. xprof/protobuf/hardware_types_pb2.py +40 -0
  16. xprof/protobuf/hlo_stats_pb2.py +39 -0
  17. xprof/protobuf/inference_stats_pb2.py +86 -0
  18. xprof/protobuf/input_pipeline_pb2.py +52 -0
  19. xprof/protobuf/kernel_stats_pb2.py +38 -0
  20. xprof/protobuf/memory_profile_pb2.py +54 -0
  21. xprof/protobuf/memory_viewer_preprocess_pb2.py +49 -0
  22. xprof/protobuf/op_metrics_pb2.py +65 -0
  23. xprof/protobuf/op_profile_pb2.py +49 -0
  24. xprof/protobuf/op_stats_pb2.py +71 -0
  25. xprof/protobuf/overview_page_pb2.py +64 -0
  26. xprof/protobuf/pod_stats_pb2.py +45 -0
  27. xprof/protobuf/pod_viewer_pb2.py +61 -0
  28. xprof/protobuf/power_metrics_pb2.py +38 -0
  29. xprof/protobuf/roofline_model_pb2.py +42 -0
  30. xprof/protobuf/smart_suggestion_pb2.py +38 -0
  31. xprof/protobuf/source_info_pb2.py +36 -0
  32. xprof/protobuf/source_stats_pb2.py +48 -0
  33. xprof/protobuf/steps_db_pb2.py +76 -0
  34. xprof/protobuf/task_pb2.py +37 -0
  35. xprof/protobuf/tf_data_stats_pb2.py +72 -0
  36. xprof/protobuf/tf_function_pb2.py +52 -0
  37. xprof/protobuf/tf_stats_pb2.py +40 -0
  38. xprof/protobuf/tfstreamz_pb2.py +40 -0
  39. xprof/protobuf/topology_pb2.py +50 -0
  40. xprof/protobuf/tpu_input_pipeline_pb2.py +43 -0
  41. xprof/protobuf/trace_events_old_pb2.py +54 -0
  42. xprof/protobuf/trace_events_pb2.py +64 -0
  43. xprof/protobuf/trace_events_raw_pb2.py +45 -0
  44. xprof/protobuf/trace_filter_config_pb2.py +40 -0
  45. xprof/server.py +319 -0
  46. xprof/standalone/base_plugin.py +52 -0
  47. xprof/standalone/context.py +22 -0
  48. xprof/standalone/data_provider.py +32 -0
  49. xprof/standalone/plugin_asset_util.py +131 -0
  50. xprof/standalone/plugin_event_multiplexer.py +185 -0
  51. xprof/standalone/tensorboard_shim.py +31 -0
  52. xprof/static/bundle.js +130500 -0
  53. xprof/static/index.html +64 -0
  54. xprof/static/index.js +3 -0
  55. xprof/static/materialicons.woff2 +0 -0
  56. xprof/static/styles.css +1 -0
  57. xprof/static/trace_viewer_index.html +3929 -0
  58. xprof/static/trace_viewer_index.js +15906 -0
  59. xprof/static/zone.js +3558 -0
  60. xprof/version.py +17 -0
  61. xprof_nightly-2.22.3a20251208.dist-info/METADATA +301 -0
  62. xprof_nightly-2.22.3a20251208.dist-info/RECORD +65 -0
  63. xprof_nightly-2.22.3a20251208.dist-info/WHEEL +5 -0
  64. xprof_nightly-2.22.3a20251208.dist-info/entry_points.txt +5 -0
  65. xprof_nightly-2.22.3a20251208.dist-info/top_level.txt +1 -0
xprof/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright 2025 The XProf Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Entry point for the TensorBoard plugin package for XProf.
16
+
17
+ Public submodules:
18
+ profile_plugin: The TensorBoard plugin integration.
19
+ profile_plugin_loader: TensorBoard's entrypoint for the plugin.
20
+ server: Standalone server entrypoint.
21
+ version: The version of the plugin.
22
+ """
Binary file
@@ -0,0 +1,87 @@
1
+ # Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Converts data between different formats."""
17
+
18
+ import csv
19
+ import io
20
+ import json
21
+ from typing import List, Optional
22
+
23
+
24
def json_to_csv(
    json_string: str,
    columns_order: Optional[List[str]] = None,
    separator: str = ",",
) -> str:
  """Renders a single JSON object as a two-line CSV document.

  The first line holds the column names and the second line the matching
  values. Booleans are written as `true`/`false`, JSON `null` produces an empty
  cell, and every other value is written as `str(value)`.

  Args:
    json_string: Text containing exactly one JSON object.
    columns_order: Optional. Column IDs in the order they should appear in the
      output. Every entry must be a key of the JSON object. Defaults to the
      key order of the JSON object itself.
    separator: Optional. The delimiter placed between cells.

  Returns:
    The CSV text, or an empty string when there are no columns to write.
    Example result for '{"a": 1, "b": "z", "c": null}':
      a,b,c
      1,z,

  Raises:
    ValueError: `json_string` is not valid JSON, is not a single object, or
      `columns_order` names a column that is missing from the object.
  """
  try:
    record = json.loads(json_string)
  except json.JSONDecodeError as err:
    raise ValueError(f"Invalid JSON string: {err}") from err

  if not isinstance(record, dict):
    raise ValueError("JSON data must be a single object")

  known_columns = list(record)
  if columns_order is None:
    columns_order = known_columns

  if any(column not in known_columns for column in columns_order):
    raise ValueError("columns_order must be a list of all column IDs")

  if not columns_order:
    return ""

  def to_cell(value):
    # None is handed to the csv writer untouched; it emits an empty cell.
    if value is None:
      return value
    # bool is checked explicitly so the cell reads true/false, not True/False.
    if isinstance(value, bool):
      return "true" if value else "false"
    return str(value)

  out = io.StringIO(newline="")
  writer = csv.writer(out, delimiter=separator, lineterminator="\n")
  writer.writerow(list(columns_order))
  writer.writerow([to_cell(record.get(column)) for column in columns_order])
  return out.getvalue()
@@ -0,0 +1,232 @@
1
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """For conversion of raw files to tool data.
16
+
17
+ Usage:
18
+ data = xspace_to_tool_data(xplane, tool, params)
19
+ data = tool_proto_to_tool_data(tool_proto, tool, params)
20
+ """
21
+
22
+ from __future__ import absolute_import
23
+ from __future__ import division
24
+ from __future__ import print_function
25
+
26
+ import logging
27
+
28
+ from xprof.convert import csv_writer
29
+ from xprof.convert import trace_events_json
30
+ from xprof.protobuf import trace_events_old_pb2
31
+ from xprof.convert import _pywrap_profiler_plugin
32
+
33
+
34
+ logger = logging.getLogger('tensorboard')
35
+
36
+
37
def process_raw_trace(raw_trace):
  """Turns a serialized trace proto into the JSON text used by the UI.

  Args:
    raw_trace: Serialized bytes, parsed as a `trace_events_old_pb2.Trace`.

  Returns:
    A single string made by concatenating every chunk that
    `trace_events_json.TraceEventsJsonStream` yields for the parsed trace.
  """
  parsed_trace = trace_events_old_pb2.Trace()
  parsed_trace.ParseFromString(raw_trace)
  json_chunks = trace_events_json.TraceEventsJsonStream(parsed_trace)
  return ''.join(json_chunks)
42
+
43
+
44
def xspace_to_tools_data_from_byte_string(xspace_byte_list, filenames, tool,
                                          params):
  """Helper function for getting XSpace tool data from byte strings.

  Delegates to `xspace_to_tool_data`, swapping the path-based converter for
  one that consumes in-memory XSpace bytes plus the names of the files they
  were read from.

  Args:
    xspace_byte_list: A list of byte strings read from XSpace proto files.
    filenames: Names of the read files.
    tool: A string of tool name.
    params: user input parameters.

  Returns:
    Whatever `xspace_to_tool_data` returns: the tool data and the content type
    for the response.
  """

  def xspace_wrapper_func(xspace_arg, tool_arg, options=None):
    # A fresh dict per call replaces the former shared mutable `{}` default.
    if options is None:
      options = {}
    # `filenames` is captured from the enclosing call.
    return _pywrap_profiler_plugin.xspace_to_tools_data_from_byte_string(
        xspace_arg, filenames, tool_arg, options)

  return xspace_to_tool_data(xspace_byte_list, tool, params,
                             xspace_wrapper_func)
65
+
66
+
67
def xspace_to_tool_names(xspace_paths):
  """Lists the tools that can be produced from the given XSpace files.

  Args:
    xspace_paths: A list of XSpace paths.

  Returns:
    The tool names obtained by decoding the converter's `tool_names` output and
    splitting it on commas, or an empty list when the converter reports
    failure.
  """
  encoded_names, ok = _pywrap_profiler_plugin.xspace_to_tools_data(
      xspace_paths, 'tool_names')
  if not ok:
    return []
  return encoded_names.decode().split(',')
81
+
82
+
83
def _stats_payload(json_data, tqx):
  """Returns `json_data` as CSV when `tqx` requests CSV output, else as is."""
  if tqx == 'out:csv':
    return csv_writer.json_to_csv(json_data)
  return json_data


def xspace_to_tool_data(
    xspace_paths,
    tool,
    params,
    xspace_wrapper_func=_pywrap_profiler_plugin.xspace_to_tools_data):
  """Converts XSpace to tool data string.

  Args:
    xspace_paths: A list of XSpace paths.
    tool: A string of tool name. A trailing '^' (old tool format) is stripped
      and a warning is logged.
    params: user input parameters. The dict and its nested option dicts are
      only read, never modified.
    xspace_wrapper_func: A callable that takes a list of strings, a tool and an
      options dict and returns `(raw_data, success)`. If failed, raw data
      contains the error message.

  Returns:
    A `(data, content_type)` tuple: the tool data and the content type for the
    response. `data` is None when the tool is unknown or when the conversion
    failed (except for `graph_viewer`, which raises instead).

  Raises:
    ValueError: The `graph_viewer` conversion failed; the message is the error
      text returned by `xspace_wrapper_func`.
    AssertionError: `trace_viewer` or `memory_profile` was requested with a
      number of XSpace paths other than one.
  """
  # `endswith` (rather than `tool[-1]`) keeps an empty tool name from raising
  # IndexError; it falls through to the "not a known xplane tool" warning.
  if tool.endswith('^'):
    old_tool = tool
    tool = tool[:-1]  # Remove the trailing '^'
    logger.warning(
        'Received old tool format: %s; mapped to new format: %s', old_tool, tool
    )
  data = None
  content_type = 'application/json'
  # tqx: gViz output format
  tqx = params.get('tqx', '')
  use_saved_result = params.get('use_saved_result', True)
  options = {'use_saved_result': use_saved_result}
  if tool == 'trace_viewer':
    # Trace viewer handles one host at a time.
    assert len(xspace_paths) == 1
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = process_raw_trace(raw_data)
  elif tool == 'trace_viewer@':
    # Copy so that the keys added below do not leak into the caller's params.
    options = dict(params.get('trace_viewer_options', {}))
    options['use_saved_result'] = use_saved_result
    options['hosts'] = params.get('hosts', [])
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool in ('overview_page', 'input_pipeline_analyzer', 'pod_viewer',
                'hlo_stats', 'roofline_model', 'inference_profile'):
    # These tools take the default options and return the converter output
    # unchanged.
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool == 'framework_op_stats':
    json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if not success:
      # Try legacy tool name: Handle backward compatibility with lower TF
      # version.
      # TODO(b/419013992): Remove this tool completely as it has been deprecated
      legacy_tool = 'tensorflow_stats'
      json_data, success = xspace_wrapper_func(
          xspace_paths, legacy_tool, options
      )
    if success:
      data = _stats_payload(json_data, tqx)
  elif tool == 'kernel_stats':
    json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = _stats_payload(json_data, tqx)
  elif tool == 'memory_profile':
    # Memory profile handles one host at a time.
    assert len(xspace_paths) == 1
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool in ('op_profile', 'hlo_op_profile'):
    options['group_by'] = params.get('group_by', 'program')
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
  elif tool == 'graph_viewer':
    download_hlo_types = ['pb', 'pbtxt', 'json', 'short_txt', 'long_txt']
    graph_html_type = 'graph'
    # Copy so that the key added below does not leak into the caller's params.
    options = dict(params.get('graph_viewer_options', {}))
    options['use_saved_result'] = use_saved_result
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
      content_type = 'text/plain'
      data_type = options.get('type', '')
      if data_type in download_hlo_types:
        content_type = 'application/octet-stream'
      if data_type == graph_html_type:
        content_type = 'text/html'
    else:
      # TODO(tf-profiler) Handle errors for other tools as well,
      # to pass along the error message to client
      if isinstance(raw_data, bytes):
        raw_data = raw_data.decode('utf-8')
      raise ValueError(raw_data)
  elif tool == 'memory_viewer':
    view_memory_allocation_timeline = params.get(
        'view_memory_allocation_timeline', False
    )
    options = {
        'module_name': params.get('module_name'),
        'view_memory_allocation_timeline': view_memory_allocation_timeline,
        'memory_space': params.get('memory_space', ''),
    }
    raw_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = raw_data
      if view_memory_allocation_timeline:
        content_type = 'text/html'
  elif tool == 'megascale_stats':
    options = {'host_name': params.get('host')}
    json_data, success = xspace_wrapper_func(xspace_paths, tool, options)
    if success:
      data = json_data
  else:
    logger.warning('%s is not a known xplane tool', tool)
  return data, content_type
@@ -0,0 +1,105 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """Converts trace events to JSON format consumed by catapult trace viewer."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import json
23
+ import six
24
+
25
# Values for type (ph) and s (scope) parameters in catapult trace format.
_TYPE_METADATA = 'M'
_TYPE_COMPLETE = 'X'
_TYPE_INSTANT = 'i'
_SCOPE_THREAD = 't'


class TraceEventsJsonStream(object):
  """A streaming trace file in the format expected by catapult trace viewer.

  Iterating over this yields a sequence of string chunks, so it is suitable for
  returning in a werkzeug Response.
  """

  def __init__(self, proto):
    """Create an iterable JSON stream over the supplied Trace.

    Args:
      proto: a tensorboard.profile.Trace protobuf
    """
    self._proto = proto

  def _events(self):
    """Iterator over all catapult trace events, as python values.

    Yields:
      First the metadata events (process/thread names and sort indices) for
      every device and resource in ascending id order, then one event per
      entry of `trace_events`.
    """
    # `.items()` is used directly; the Python 2 `six.iteritems` shim is not
    # needed on Python 3.
    for did, device in sorted(self._proto.devices.items()):
      if device.name:
        yield dict(
            ph=_TYPE_METADATA,
            pid=did,
            name='process_name',
            args=dict(name=device.name))
      yield dict(
          ph=_TYPE_METADATA,
          pid=did,
          name='process_sort_index',
          args=dict(sort_index=did))
      for rid, resource in sorted(device.resources.items()):
        if resource.name:
          yield dict(
              ph=_TYPE_METADATA,
              pid=did,
              tid=rid,
              name='thread_name',
              args=dict(name=resource.name))
        yield dict(
            ph=_TYPE_METADATA,
            pid=did,
            tid=rid,
            name='thread_sort_index',
            args=dict(sort_index=rid))
    for event in self._proto.trace_events:
      yield self._event(event)

  def _event(self, event):
    """Converts a TraceEvent proto into a catapult trace event python value.

    Args:
      event: An object exposing `device_id`, `resource_id`, `name`,
        `timestamp_ps`, `duration_ps` and a mapping `args`.

    Returns:
      A dict describing a complete ('X') event when the duration is non-zero,
      otherwise a thread-scoped instant ('i') event. Picosecond values are
      divided by 1e6. An `args` entry is present only when the event has
      arguments.
    """
    result = dict(
        pid=event.device_id,
        tid=event.resource_id,
        name=event.name,
        ts=event.timestamp_ps / 1000000.0)
    if event.duration_ps:
      result['ph'] = _TYPE_COMPLETE
      result['dur'] = event.duration_ps / 1000000.0
    else:
      result['ph'] = _TYPE_INSTANT
      result['s'] = _SCOPE_THREAD
    # Copy the argument map in one step; omit the key entirely when empty.
    event_args = dict(event.args)
    if event_args:
      result['args'] = event_args
    return result

  def __iter__(self):
    """Returns an iterator of string chunks of a complete JSON document."""
    yield '{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},\n'
    yield '"traceEvents":[\n'
    for event in self._events():
      yield json.dumps(event, sort_keys=True)
      yield ',\n'
    # Add one fake event to avoid dealing with no-trailing-comma rule.
    yield '{}]}\n'
@@ -0,0 +1,100 @@
1
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Tensorflow based mnist model and input datasets helper."""
16
+
17
+ _NUM_CLASSES = 10
18
+
19
+
20
def get_input_datasets(tf, use_bfloat16=False):
  """Loads MNIST through Keras and wraps it in repeating, batched datasets.

  Images are reshaped to the backend's image data format and scaled by 1/255;
  labels are one-hot encoded. Each dataset repeats, casts its images to the
  requested dtype and is batched in groups of 64 with the remainder dropped.

  Args:
    tf: version specific tensorflow (either tf1 or tf2).
    use_bfloat16: Boolean; cast images to bfloat16 when true, float32 otherwise.

  Returns:
    A tuple `(train_ds, eval_ds, test_ds, input_shape)`. The eval and test
    datasets are both built from the MNIST test split; the test dataset holds
    images only.
  """
  rows, cols = 28, 28
  image_dtype = tf.bfloat16 if use_bfloat16 else tf.float32

  # The data, split between train and test sets.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  # Decide the per-image layout once, then apply it to both splits.
  if tf.keras.backend.image_data_format() == 'channels_first':
    input_shape = (1, rows, cols)
  else:
    input_shape = (rows, cols, 1)
  x_train = x_train.reshape(x_train.shape[0], *input_shape)
  x_test = x_test.reshape(x_test.shape[0], *input_shape)

  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, _NUM_CLASSES)
  y_test = tf.keras.utils.to_categorical(y_test, _NUM_CLASSES)

  def build_pipeline(tensors, cast_fn):
    dataset = tf.data.Dataset.from_tensor_slices(tensors)
    dataset = dataset.repeat()
    dataset = dataset.map(cast_fn)
    return dataset.batch(64, drop_remainder=True)

  def cast_labelled(x, y):
    return (tf.cast(x, image_dtype), y)

  def cast_unlabelled(x):
    return tf.cast(x, image_dtype)

  train_ds = build_pipeline((x_train, y_train), cast_labelled)
  eval_ds = build_pipeline((x_test, y_test), cast_labelled)
  test_ds = build_pipeline(x_test, cast_unlabelled)

  return train_ds, eval_ds, test_ds, input_shape
75
+
76
+
77
def get_model(tf, input_shape):
  """Assembles the Sequential CNN used to classify MNIST digits.

  Args:
    tf: version specific tensorflow (either tf1 or tf2).
    input_shape: Shape of the input depending on the `image_data_format`.

  Returns:
    A Keras Sequential model whose last layer is a softmax over
    `_NUM_CLASSES` outputs.
  """
  keras_layers = tf.keras.layers
  stack = [
      keras_layers.Conv2D(
          32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
      keras_layers.Conv2D(64, (3, 3), activation='relu'),
      keras_layers.MaxPooling2D(pool_size=(2, 2)),
      keras_layers.Dropout(0.25),
      keras_layers.Flatten(),
      keras_layers.Dense(128, activation='relu'),
      keras_layers.Dropout(0.5),
      keras_layers.Dense(_NUM_CLASSES, activation='softmax'),
  ]
  model = tf.keras.models.Sequential()
  for layer in stack:
    model.add(layer)
  return model
@@ -0,0 +1,40 @@
1
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Programmatic functionality for tensorflow profile session."""
16
+
17
+ import types
18
+
19
+
20
class TensorflowProfilerSession:
  """Context manager for a Tensorflow programmatic profile session.

  Entering the context starts the profiler writing to `log_dir`; leaving it
  stops the profiler.
  """

  def __init__(
      self, tf: types.ModuleType, log_dir: str, python_tracer_level: int = 1
  ):
    """Stores the profiler API and session settings.

    Args:
      tf: Version dependent tensorflow module; its `profiler.experimental`
        attribute is used to drive the session.
      log_dir: Directory handed to the profiler's `start` call.
      python_tracer_level: Value passed as `python_tracer_level` to
        `ProfilerOptions`.
    """
    self.profiler = tf.profiler.experimental
    self.log_dir = log_dir
    self.python_tracer_level = python_tracer_level

  def __enter__(self):
    """Starts the profile session.

    Returns:
      This session, so that `with TensorflowProfilerSession(...) as session`
      binds the session object instead of None.
    """
    options = self.profiler.ProfilerOptions(
        host_tracer_level=2, python_tracer_level=self.python_tracer_level
    )
    self.profiler.start(self.log_dir, options=options)
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Ends the current profile session.

    Returns None, so an exception raised inside the `with` body propagates.
    """
    self.profiler.stop()