jax_hpc_profiler-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- /dev/null
+++ jax_hpc_profiler/__init__.py
@@ -0,0 +1,9 @@
+ from .create_argparse import create_argparser
+ from .plotting import plot_strong_scaling, plot_weak_scaling
+ from .timer import Timer
+ from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
+
+ __all__ = [
+     'create_argparser', 'plot_strong_scaling', 'plot_weak_scaling', 'Timer',
+     'clean_up_csv', 'concatenate_csvs', 'plot_with_pdims_strategy'
+ ]
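
The `__init__.py` above re-exports the package's public API. A minimal import sketch, assuming the wheel installs the top-level package as `jax_hpc_profiler` (inferred from the wheel name and the relative imports; not stated explicitly in the diff):

```python
# Hypothetical consumer of the public API re-exported above.
from jax_hpc_profiler import Timer, plot_strong_scaling, plot_weak_scaling

timer = Timer(save_jaxpr=True)  # see timer.py below for the full workflow
```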
--- /dev/null
+++ jax_hpc_profiler/create_argparse.py
@@ -0,0 +1,158 @@
+ import argparse
+
+
+ def create_argparser():
+     """
+     Build the argument parser for the HPC Plotter package and parse the
+     command-line arguments.
+
+     Returns
+     -------
+     argparse.Namespace
+         Parsed and validated arguments.
+     """
+     parser = argparse.ArgumentParser(
+         description='HPC Plotter for benchmarking data')
+
+     # Subcommands keep the concat and plot interfaces mutually exclusive
+     subparsers = parser.add_subparsers(dest='command', required=True)
+
+     concat_parser = subparsers.add_parser('concat',
+                                           help='Concatenate CSV files')
+     concat_parser.add_argument('input',
+                                type=str,
+                                help='Input directory for concatenation')
+     concat_parser.add_argument('output',
+                                type=str,
+                                help='Output directory for concatenation')
+
+     # Arguments for plotting
+     plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
+     plot_parser.add_argument('-f',
+                              '--csv_files',
+                              nargs='+',
+                              help='List of CSV files to plot',
+                              required=True)
+     plot_parser.add_argument('-g',
+                              '--gpus',
+                              nargs='*',
+                              type=int,
+                              help='List of GPU counts to plot')
+     plot_parser.add_argument('-d',
+                              '--data_size',
+                              nargs='*',
+                              type=int,
+                              help='List of data sizes to plot')
+
+     # pdims related arguments
+     plot_parser.add_argument('-fd',
+                              '--filter_pdims',
+                              nargs='*',
+                              help='List of pdims to filter, e.g., 1x4 2x2 4x8')
+     plot_parser.add_argument(
+         '-ps',
+         '--pdim_strategy',
+         choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
+         nargs='*',
+         default=['plot_fastest'],
+         help='Strategy for plotting pdims')
+
+     # Function and precision related arguments
+     plot_parser.add_argument(
+         '-pr',
+         '--precision',
+         choices=['float32', 'float64'],
+         default=['float32', 'float64'],
+         nargs='*',
+         help='Precision to filter by (float32 or float64)')
+     plot_parser.add_argument('-fn',
+                              '--function_name',
+                              nargs='+',
+                              default=['FFT'],
+                              help='Function names to filter')
+
+     # Time or memory related arguments
+     plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
+     plotting_group.add_argument('-pt',
+                                 '--plot_times',
+                                 nargs='*',
+                                 choices=[
+                                     'jit_time', 'min_time', 'max_time',
+                                     'mean_time', 'std_time', 'last_time'
+                                 ],
+                                 help='Time columns to plot')
+     plotting_group.add_argument('-pm',
+                                 '--plot_memory',
+                                 nargs='*',
+                                 choices=[
+                                     'generated_code', 'argument_size',
+                                     'output_size', 'temp_size'
+                                 ],
+                                 help='Memory columns to plot')
+     plot_parser.add_argument('-mu',
+                              '--memory_units',
+                              default='GB',
+                              help='Memory units to plot (KB, MB, GB, TB)')
+
+     # Plot customization arguments
+     plot_parser.add_argument('-fs',
+                              '--figure_size',
+                              nargs=2,
+                              type=int,
+                              help='Figure size')
+     plot_parser.add_argument('-o',
+                              '--output',
+                              help='Output file (if none then only show plot)',
+                              default=None)
+     plot_parser.add_argument('-db',
+                              '--dark_bg',
+                              action='store_true',
+                              help='Use dark background for plotting')
+     plot_parser.add_argument('-pd',
+                              '--print_decompositions',
+                              action='store_true',
+                              help='Print decompositions on plot')
+
+     # Backend related arguments
+     plot_parser.add_argument('-b',
+                              '--backends',
+                              nargs='*',
+                              default=['MPI', 'NCCL', 'MPI4JAX'],
+                              help='List of backends to include')
+
+     # Scaling type argument
+     plot_parser.add_argument('-sc',
+                              '--scaling',
+                              choices=['Weak', 'Strong'],
+                              required=True,
+                              help='Scaling type (Weak or Strong)')
+
+     # Label customization argument
+     # argparse interpolates '%' in help strings, so literal percent signs
+     # must be doubled or `--help` raises a formatting error.
+     plot_parser.add_argument(
+         '-l',
+         '--label_text',
+         type=str,
+         help='Custom label for the plot. You can use placeholders: '
+         '%%decomposition%% (or %%p%%), %%precision%% (or %%pr%%), '
+         '%%plot_name%% (or %%pn%%), %%function%% (or %%f%%), '
+         '%%backend%% (or %%b%%), %%node%% (or %%n%%), '
+         '%%methodname%% (or %%m%%)',
+         default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
+
+     args = parser.parse_args()
+
+     # 'plot_all' and 'plot_fastest' each supersede the other strategies,
+     # so mixed selections collapse to a single strategy.
+     if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
+         print(
+             "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
+         )
+         args.pdim_strategy = ['plot_all']
+
+     if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
+         print(
+             "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
+         )
+         args.pdim_strategy = ['plot_fastest']
+
+     if args.plot_times is not None:
+         args.plot_columns = args.plot_times
+     elif args.plot_memory is not None:
+         args.plot_columns = args.plot_memory
+     else:
+         raise ValueError('Either plot_times or plot_memory should be provided')
+
+     return args
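
Since `create_argparser()` reads `sys.argv` directly, a quick way to exercise the `plot` subcommand is to patch `sys.argv` before calling it. A sketch with a made-up CSV path (the program name in `argv[0]` is cosmetic):

```python
import sys

from jax_hpc_profiler import create_argparser

# Emulate: plot -f out/jaxdecomp.csv -sc Strong -pt mean_time min_time -g 8 16 32
sys.argv = [
    "jax-hpc-profiler", "plot",
    "-f", "out/jaxdecomp.csv",       # hypothetical CSV produced by Timer.report
    "-sc", "Strong",                 # strong scaling: time vs. number of GPUs
    "-pt", "mean_time", "min_time",  # time columns -> args.plot_columns
    "-g", "8", "16", "32",
]
args = create_argparser()
print(args.command, args.scaling, args.plot_columns)
# plot Strong ['mean_time', 'min_time']
```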
--- /dev/null
+++ jax_hpc_profiler/__main__.py
@@ -0,0 +1,57 @@
+ import sys
+
+ from .create_argparse import create_argparser
+ from .plotting import plot_strong_scaling, plot_weak_scaling
+ from .utils import clean_up_csv, concatenate_csvs
+
+
+ def main():
+     args = create_argparser()
+
+     if args.command == 'concat':
+         input_dir, output_dir = args.input, args.output
+         concatenate_csvs(input_dir, output_dir)
+     elif args.command == 'plot':
+         dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
+             args.csv_files, args.precision, args.function_name, args.gpus,
+             args.data_size, args.filter_pdims, args.pdim_strategy,
+             args.backends, args.memory_units)
+         if len(dataframes) == 0:
+             print("No dataframes found for the given arguments. Exiting...")
+             sys.exit(1)
+         print(
+             f"requested GPUs: {args.gpus} available GPUs: {available_gpu_counts}"
+         )
+         # Keep only the requested GPU counts and data sizes that are actually
+         # present in the CSVs; fall back to everything available when the
+         # corresponding flag was omitted (argparse leaves it as None).
+         args.gpus = [
+             gpu for gpu in (args.gpus or available_gpu_counts)
+             if gpu in available_gpu_counts
+         ]
+         args.data_size = [
+             data_size for data_size in (args.data_size or available_data_sizes)
+             if data_size in available_data_sizes
+         ]
+         if len(args.gpus) == 0:
+             print("No dataframes found for the given GPUs. Exiting...")
+             sys.exit(1)
+         if len(args.data_size) == 0:
+             print("No dataframes found for the given data sizes. Exiting...")
+             sys.exit(1)
+
+         if args.scaling == 'Weak':
+             plot_weak_scaling(dataframes, args.gpus, args.figure_size,
+                               args.output, args.dark_bg,
+                               args.print_decompositions, args.backends,
+                               args.precision, args.function_name,
+                               args.plot_columns, args.memory_units,
+                               args.label_text, args.pdim_strategy)
+         elif args.scaling == 'Strong':
+             plot_strong_scaling(dataframes, args.data_size, args.figure_size,
+                                 args.output, args.dark_bg,
+                                 args.print_decompositions, args.backends,
+                                 args.precision, args.function_name,
+                                 args.plot_columns, args.memory_units,
+                                 args.label_text, args.pdim_strategy)
+
+
+ if __name__ == "__main__":
+     main()
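
Assuming the file above is installed as the package's `__main__.py`, both subcommands are reachable through `python -m`. A sketch with made-up directory names:

```python
# Hypothetical end-to-end driver; paths are illustrative.
import subprocess

# Merge the per-run CSVs from raw/ into merged/.
subprocess.run(
    ["python", "-m", "jax_hpc_profiler", "concat", "raw/", "merged/"],
    check=True)

# Plot weak scaling of the mean times to a PNG instead of an interactive window.
subprocess.run([
    "python", "-m", "jax_hpc_profiler", "plot",
    "-f", "merged/jaxdecomp.csv",
    "-sc", "Weak",
    "-pt", "mean_time",
    "-o", "weak_scaling.png",
], check=True)
```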
--- /dev/null
+++ jax_hpc_profiler/plotting.py
@@ -0,0 +1,214 @@
+ from itertools import product
+ from typing import Dict, List, Optional
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from matplotlib.axes import Axes
+ from matplotlib.patches import FancyBboxPatch
+
+ from .utils import plot_with_pdims_strategy
+
+ np.seterr(divide='ignore')
+ plt.rcParams.update({'font.size': 10})
+
+
+ def configure_axes(ax: Axes,
+                    x_values: List[int],
+                    y_values: List[float],
+                    xlabel: str,
+                    title: str,
+                    plotting_memory: bool = False,
+                    memory_units: str = 'bytes'):
+     """
+     Configure the axes for the plot.
+
+     Parameters
+     ----------
+     ax : Axes
+         The axes to configure.
+     x_values : List[int]
+         The x-axis values.
+     y_values : List[float]
+         The y-axis values.
+     xlabel : str
+         The label for the x-axis.
+     title : str
+         The title of the subplot.
+     plotting_memory : bool, optional
+         Whether the y-axis shows memory rather than time, by default False.
+     memory_units : str, optional
+         Units for the memory axis, by default 'bytes'.
+     """
+     ylabel = f'Memory ({memory_units})' if plotting_memory else 'Time (milliseconds)'
+     # log2-scaled x-axis: doubling GPU counts/data sizes gives evenly spaced ticks
+     f2 = lambda x: np.log2(x)
+     g2 = lambda x: 2**x
+     ax.set_xlim([min(x_values), max(x_values)])
+     y_min, y_max = min(y_values) * 0.9, max(y_values) * 1.1
+     ax.set_title(title)
+     ax.set_ylim([y_min, y_max])
+     ax.set_xscale('function', functions=(f2, g2))
+     if not plotting_memory:
+         ax.set_yscale('symlog')
+         time_ticks = [
+             10**t for t in range(int(np.floor(np.log10(y_min))), 1 +
+                                  int(np.ceil(np.log10(y_max))))
+         ]
+         ax.set_yticks(time_ticks)
+     ax.set_xticks(x_values)
+     ax.set_xlabel(xlabel)
+     ax.set_ylabel(ylabel)
+     for x_value in x_values:
+         ax.axvline(x=x_value, color='gray', linestyle='--', alpha=0.5)
+     ax.legend(loc='lower center',
+               bbox_to_anchor=(0.5, 0.05),
+               ncol=4,
+               prop={'size': 14})
+
+
+ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
+                  fixed_sizes: List[int],
+                  size_column: str,
+                  fixed_column: str,
+                  xlabel: str,
+                  title: str,
+                  figure_size: tuple = (6, 4),
+                  output: Optional[str] = None,
+                  dark_bg: bool = False,
+                  print_decompositions: bool = False,
+                  backends: List[str] = ['NCCL'],
+                  precisions: List[str] = ['float32'],
+                  functions: List[str] | None = None,
+                  plot_columns: List[str] = ['mean_time'],
+                  memory_units: str = 'bytes',
+                  label_text: str = 'plot',
+                  pdims_strategy: List[str] = ['plot_fastest']):
+     """
+     General scaling plot as a function of the number of GPUs or the data size.
+
+     Parameters
+     ----------
+     dataframes : Dict[str, pd.DataFrame]
+         Dictionary of method names to dataframes.
+     fixed_sizes : List[int]
+         List of fixed sizes (data or GPUs) to plot, one subplot each.
+     size_column : str
+         Column name for the size axis ('x' for weak scaling, 'gpus' for strong scaling).
+     fixed_column : str
+         Column name for the fixed axis ('gpus' for weak scaling, 'x' for strong scaling).
+     xlabel : str
+         Label for the x-axis.
+     title : str
+         Title prefix for each subplot (the fixed size is appended).
+     figure_size : tuple, optional
+         Size of the figure, by default (6, 4).
+     output : Optional[str], optional
+         Output file to save the plot, by default None (show the plot).
+     dark_bg : bool, optional
+         Whether to use dark background for the plot, by default False.
+     print_decompositions : bool, optional
+         Whether to print decompositions on the plot, by default False.
+     backends : List[str], optional
+         List of backends to include, by default ['NCCL'].
+     precisions : List[str], optional
+         List of precisions to include, by default ['float32'].
+     functions : List[str], optional
+         Function names to include; by default all functions found in the data.
+     plot_columns : List[str], optional
+         Time or memory columns to plot, by default ['mean_time'].
+     memory_units : str, optional
+         Units for memory columns, by default 'bytes'.
+     label_text : str, optional
+         Label template for the curves, by default 'plot'.
+     pdims_strategy : List[str], optional
+         Strategy for plotting pdims ('plot_all' or 'plot_fastest'), by default ['plot_fastest'].
+     """
+     if dark_bg:
+         plt.style.use('dark_background')
+
+     num_subplots = len(fixed_sizes)
+     num_rows = int(np.ceil(np.sqrt(num_subplots)))
+     num_cols = int(np.ceil(num_subplots / num_rows))
+
+     fig, axs = plt.subplots(num_rows, num_cols, figsize=figure_size)
+     if num_subplots > 1:
+         axs = axs.flatten()
+     else:
+         axs = [axs]
+
+     for i, fixed_size in enumerate(fixed_sizes):
+         ax: Axes = axs[i]
+         # Collect the plotted values across all methods so the axis limits
+         # cover every curve in this subplot.
+         x_values = []
+         y_values = []
+
+         for method, df in dataframes.items():
+             filtered_method_df = df[df[fixed_column] == int(fixed_size)]
+             if filtered_method_df.empty:
+                 continue
+             filtered_method_df = filtered_method_df.sort_values(
+                 by=[size_column])
+             functions = pd.unique(filtered_method_df['function']
+                                   ) if functions is None else functions
+             combinations = product(backends, precisions, functions,
+                                    plot_columns)
+
+             for backend, precision, function, plot_column in combinations:
+                 filtered_params_df = filtered_method_df[
+                     (filtered_method_df['backend'] == backend)
+                     & (filtered_method_df['precision'] == precision) &
+                     (filtered_method_df['function'] == function)]
+                 if filtered_params_df.empty:
+                     continue
+                 x_vals, y_vals = plot_with_pdims_strategy(
+                     ax, filtered_params_df, method, pdims_strategy,
+                     print_decompositions, size_column, plot_column, label_text)
+
+                 x_values.extend(x_vals)
+                 y_values.extend(y_vals)
+
+         if not x_values:
+             # Nothing matched the filters for this subplot; skip axis styling.
+             continue
+         plotting_memory = 'time' not in plot_columns[0].lower()
+         configure_axes(ax, x_values, y_values, xlabel,
+                        f"{title} {fixed_size}", plotting_memory, memory_units)
+
+     for i in range(num_subplots, num_rows * num_cols):
+         fig.delaxes(axs[i])
+
+     fig.tight_layout()
+     # Rounded border around the whole figure, drawn in figure coordinates.
+     rect = FancyBboxPatch((0.1, 0.1),
+                           0.8,
+                           0.8,
+                           boxstyle="round,pad=0.02",
+                           ec="black",
+                           fc="none",
+                           transform=fig.transFigure)
+     fig.patches.append(rect)
+     if output is None:
+         plt.show()
+     else:
+         plt.savefig(output, bbox_inches='tight', transparent=False)
+
+
+ def plot_strong_scaling(dataframes: Dict[str, pd.DataFrame],
+                         fixed_data_size: List[int],
+                         figure_size: tuple = (6, 4),
+                         output: Optional[str] = None,
+                         dark_bg: bool = False,
+                         print_decompositions: bool = False,
+                         backends: List[str] = ['NCCL'],
+                         precisions: List[str] = ['float32'],
+                         functions: List[str] | None = None,
+                         plot_columns: List[str] = ['mean_time'],
+                         memory_units: str = 'bytes',
+                         label_text: str = 'plot',
+                         pdims_strategy: List[str] = ['plot_fastest']):
+     """
+     Plot strong scaling: performance versus the number of GPUs at fixed data sizes.
+     """
+     plot_scaling(dataframes, fixed_data_size, 'gpus', 'x', 'Number of GPUs',
+                  'Data size', figure_size, output, dark_bg,
+                  print_decompositions, backends, precisions, functions,
+                  plot_columns, memory_units, label_text, pdims_strategy)
+
+
+ def plot_weak_scaling(dataframes: Dict[str, pd.DataFrame],
+                       fixed_gpu_size: List[int],
+                       figure_size: tuple = (6, 4),
+                       output: Optional[str] = None,
+                       dark_bg: bool = False,
+                       print_decompositions: bool = False,
+                       backends: List[str] = ['NCCL'],
+                       precisions: List[str] = ['float32'],
+                       functions: List[str] | None = None,
+                       plot_columns: List[str] = ['mean_time'],
+                       memory_units: str = 'bytes',
+                       label_text: str = 'plot',
+                       pdims_strategy: List[str] = ['plot_fastest']):
+     """
+     Plot weak scaling: performance versus the data size at fixed GPU counts.
+     """
+     plot_scaling(dataframes, fixed_gpu_size, 'x', 'gpus', 'Data size',
+                  'Number of GPUs', figure_size, output, dark_bg,
+                  print_decompositions, backends, precisions, functions,
+                  plot_columns, memory_units, label_text, pdims_strategy)
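
The distinctive piece of `configure_axes` is the `'function'` x-scale: plotting against log2(x) makes benchmark points at 8, 16, 32, ... GPUs equidistant. A self-contained sketch of just that trick, with illustrative numbers:

```python
import matplotlib.pyplot as plt
import numpy as np

gpus = [8, 16, 32, 64]
times_ms = [31.0, 17.4, 9.8, 6.1]  # made-up mean times in milliseconds

fig, ax = plt.subplots()
ax.plot(gpus, times_ms, marker="o", label="example method")
# forward/inverse pair: x -> log2(x) spaces the doubling steps evenly
ax.set_xscale("function", functions=(np.log2, lambda x: 2**x))
ax.set_xticks(gpus)
ax.set_xlabel("Number of GPUs")
ax.set_ylabel("Time (milliseconds)")
ax.legend()
plt.show()
```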
--- /dev/null
+++ jax_hpc_profiler/timer.py
@@ -0,0 +1,185 @@
+ import os
+ import time
+ from functools import partial
+ from typing import Callable
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jax import make_jaxpr
+ from jax.experimental.shard_map import shard_map
+ from jax.sharding import NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from tabulate import tabulate
+
+
+ class Timer:
+
+     def __init__(self, save_jaxpr=False):
+         self.jit_time = None
+         self.times = []
+         self.profiling_data = {}
+         self.compiled_code = {}
+         self.save_jaxpr = save_jaxpr
+
+     def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
+         """Time the first (compiling) call of `fun` and record the memory and
+         FLOP analysis of the compiled executable."""
+         start = time.perf_counter()
+         out = jax.jit(fun)(*args)
+         # Block on the output (or on the array at index `ndarray_arg` when the
+         # function returns a sequence) so the asynchronous device computation
+         # is included in the measurement.
+         if ndarray_arg is None:
+             out.block_until_ready()
+         else:
+             out[ndarray_arg].block_until_ready()
+         end = time.perf_counter()
+         self.jit_time = (end - start) * 1e3
+
+         if self.save_jaxpr:
+             jaxpr = make_jaxpr(fun)(*args)
+             self.compiled_code["JAXPR"] = jaxpr.pretty_print()
+
+         lowered = jax.jit(fun).lower(*args)
+         compiled = lowered.compile()
+         memory_analysis = compiled.memory_analysis()
+         self.compiled_code["LOWERED"] = lowered.as_text()
+         self.compiled_code["COMPILED"] = compiled.as_text()
+         self.profiling_data["FLOPS"] = compiled.cost_analysis()[0]['flops']
+         self.profiling_data[
+             "generated_code"] = memory_analysis.generated_code_size_in_bytes
+         self.profiling_data[
+             "argument_size"] = memory_analysis.argument_size_in_bytes
+         self.profiling_data[
+             "output_size"] = memory_analysis.output_size_in_bytes
+         self.profiling_data["temp_size"] = memory_analysis.temp_size_in_bytes
+
+         return out
+
+     def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
+         """Time one steady-state call of `fun` and append it to `self.times`."""
+         start = time.perf_counter()
+         out = fun(*args)
+         if ndarray_arg is None:
+             out.block_until_ready()
+         else:
+             out[ndarray_arg].block_until_ready()
+         end = time.perf_counter()
+         self.times.append((end - start) * 1e3)
+         return out
+
+     def _get_mean_times(self, times_array: jnp.ndarray,
+                         sharding: NamedSharding):
+         """Average the per-process timings over every mesh axis named in the
+         partition spec, via pmean inside shard_map."""
+         mesh = sharding.mesh
+         specs = sharding.spec
+         valid_letters = [letter for letter in specs if letter is not None]
+         assert len(valid_letters
+                    ) > 0, "Sharding was provided but with no partition specs"
+
+         @partial(shard_map,
+                  mesh=mesh,
+                  in_specs=specs,
+                  out_specs=P(),
+                  check_rep=False)
+         def get_mean_times(times):
+             mean = jax.lax.pmean(times, axis_name=valid_letters[0])
+             for axis_name in valid_letters[1:]:
+                 mean = jax.lax.pmean(mean, axis_name=axis_name)
+             return mean
+
+         times_array = get_mean_times(times_array)
+         times_array.block_until_ready()
+         return times_array
+
+     def report(self,
+                csv_filename: str,
+                function: str,
+                precision: str,
+                x: int,
+                y: int,
+                z: int,
+                px: int,
+                py: int,
+                backend: str,
+                nodes: int,
+                sharding: NamedSharding | None = None,
+                md_filename: str | None = None,
+                extra_info: dict = {}):
+         """Append a CSV row with the collected statistics and write a Markdown
+         report next to the CSV file. Requires a prior call to `chrono_jit`,
+         which fills `profiling_data` and `compiled_code`."""
+         times_array = jnp.array(self.times)
+
+         if md_filename is None:
+             dirname = os.path.dirname(csv_filename)
+             filename = os.path.splitext(os.path.basename(csv_filename))[0]
+             report_folder = filename if dirname == "" else f"{dirname}/{filename}"
+             print(f"report_folder: {report_folder} csv_filename: {csv_filename}")
+             os.makedirs(report_folder, exist_ok=True)
+             md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
+
+         if sharding is not None:
+             times_array = self._get_mean_times(times_array, sharding)
+
+         times_array = np.array(times_array)
+         min_time = np.min(times_array)
+         max_time = np.max(times_array)
+         mean_time = np.mean(times_array)
+         std_time = np.std(times_array)
+         last_time = times_array[-1]
+
+         flops = self.profiling_data["FLOPS"]
+         generated_code = self.profiling_data["generated_code"]
+         argument_size = self.profiling_data["argument_size"]
+         output_size = self.profiling_data["output_size"]
+         temp_size = self.profiling_data["temp_size"]
+
+         # Column order: function, precision, x, y, z, px, py, backend, nodes,
+         # jit/min/max/mean/std/last times, then the memory and FLOP counters.
+         csv_line = (
+             f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
+             f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
+             f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
+         )
+
+         with open(csv_filename, 'a') as f:
+             f.write(csv_line)
+
+         param_dict = {
+             "Function": function,
+             "Precision": precision,
+             "X": x,
+             "Y": y,
+             "Z": z,
+             "PX": px,
+             "PY": py,
+             "Backend": backend,
+             "Nodes": nodes,
+         }
+         param_dict.update(extra_info)
+         profiling_result = {
+             "JIT Time": self.jit_time,
+             "Min Time": min_time,
+             "Max Time": max_time,
+             "Mean Time": mean_time,
+             "Std Time": std_time,
+             "Last Time": last_time,
+             "Generated Code": generated_code,
+             "Argument Size": argument_size,
+             "Output Size": output_size,
+             "Temporary Size": temp_size,
+             "FLOPS": flops
+         }
+
+         with open(md_filename, 'w') as f:
+             f.write(f"# Reporting for {function}\n")
+             f.write("## Parameters\n")
+             f.write(tabulate(param_dict.items(),
+                              headers=["Parameter", "Value"],
+                              tablefmt='github'))
+             f.write("\n---\n")
+             f.write("## Profiling Data\n")
+             f.write(tabulate(profiling_result.items(),
+                              headers=["Parameter", "Value"],
+                              tablefmt='github'))
+             f.write("\n---\n")
+             f.write("## Compiled Code\n")
+             f.write("```hlo\n")
+             f.write(self.compiled_code["COMPILED"])
+             f.write("\n```\n")
+             f.write("\n---\n")
+             f.write("## Lowered Code\n")
+             f.write("```hlo\n")
+             f.write(self.compiled_code["LOWERED"])
+             f.write("\n```\n")
+             f.write("\n---\n")
+             if self.save_jaxpr:
+                 f.write("## JAXPR\n")
+                 f.write("```haskell\n")
+                 f.write(self.compiled_code["JAXPR"])
+                 f.write("\n```\n")