jax-hpc-profiler 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/__init__.py +9 -0
- jax_hpc_profiler/create_argparse.py +158 -0
- jax_hpc_profiler/main.py +57 -0
- jax_hpc_profiler/plotting.py +214 -0
- jax_hpc_profiler/timer.py +185 -0
- jax_hpc_profiler/utils.py +396 -0
- jax_hpc_profiler-0.2.0.dist-info/LICENSE +674 -0
- jax_hpc_profiler-0.2.0.dist-info/METADATA +847 -0
- jax_hpc_profiler-0.2.0.dist-info/RECORD +12 -0
- jax_hpc_profiler-0.2.0.dist-info/WHEEL +5 -0
- jax_hpc_profiler-0.2.0.dist-info/entry_points.txt +2 -0
- jax_hpc_profiler-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .create_argparse import create_argparser
|
|
2
|
+
from .plotting import plot_strong_scaling, plot_weak_scaling
|
|
3
|
+
from .timer import Timer
|
|
4
|
+
from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
'create_argparser', 'plot_strong_scaling', 'plot_weak_scaling', 'Timer',
|
|
8
|
+
'clean_up_csv', 'concatenate_csvs', 'plot_with_pdims_strategy'
|
|
9
|
+
]
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def create_argparser(argv=None):
    """
    Create the command-line parser of the HPC Plotter package and parse arguments.

    Two sub-commands are available:

    * ``concat`` -- concatenate the CSV files of an input directory into an
      output directory.
    * ``plot`` -- plot strong or weak scaling figures from benchmark CSV files.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
        exactly like ``argparse.ArgumentParser.parse_args`` does.

    Returns
    -------
    argparse.Namespace
        Parsed and validated arguments. For the ``plot`` command the namespace
        also carries ``plot_columns`` (the requested time or memory columns)
        and a normalised ``pdim_strategy`` list.
    """
    parser = argparse.ArgumentParser(
        description='HPC Plotter for benchmarking data')

    # Exactly one sub-command (concat or plot) must be given.
    subparsers = parser.add_subparsers(dest='command', required=True)

    concat_parser = subparsers.add_parser('concat',
                                          help='Concatenate CSV files')
    concat_parser.add_argument('input',
                               type=str,
                               help='Input directory for concatenation')
    concat_parser.add_argument('output',
                               type=str,
                               help='Output directory for concatenation')

    # Arguments for plotting
    plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
    plot_parser.add_argument('-f',
                             '--csv_files',
                             nargs='+',
                             help='List of CSV files to plot',
                             required=True)
    plot_parser.add_argument('-g',
                             '--gpus',
                             nargs='*',
                             type=int,
                             help='List of number of GPUs to plot')
    plot_parser.add_argument('-d',
                             '--data_size',
                             nargs='*',
                             type=int,
                             help='List of data sizes to plot')

    # pdims related arguments
    plot_parser.add_argument('-fd',
                             '--filter_pdims',
                             nargs='*',
                             help='List of pdims to filter, e.g., 1x4 2x2 4x8')
    plot_parser.add_argument(
        '-ps',
        '--pdim_strategy',
        choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
        nargs='*',
        default=['plot_fastest'],
        help='Strategy for plotting pdims')

    # Function and precision related arguments
    plot_parser.add_argument(
        '-pr',
        '--precision',
        choices=['float32', 'float64'],
        default=['float32', 'float64'],
        nargs='*',
        help='Precision to filter by (float32 or float64)')
    plot_parser.add_argument('-fn',
                             '--function_name',
                             nargs='+',
                             default=['FFT'],
                             help='Function names to filter')

    # Time or memory related arguments: exactly one of the two is required.
    plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
    plotting_group.add_argument('-pt',
                                '--plot_times',
                                nargs='*',
                                choices=[
                                    'jit_time', 'min_time', 'max_time',
                                    'mean_time', 'std_time', 'last_time'
                                ],
                                help='Time columns to plot')
    plotting_group.add_argument('-pm',
                                '--plot_memory',
                                nargs='*',
                                choices=[
                                    'generated_code', 'argument_size',
                                    'output_size', 'temp_size'
                                ],
                                help='Memory columns to plot')
    plot_parser.add_argument('-mu',
                             '--memory_units',
                             default='GB',
                             help='Memory units to plot (KB, MB, GB, TB)')

    # Plot customization arguments
    plot_parser.add_argument('-fs',
                             '--figure_size',
                             nargs=2,
                             type=int,
                             help='Figure size')
    plot_parser.add_argument('-o',
                             '--output',
                             help='Output file (if none then only show plot)',
                             default=None)
    plot_parser.add_argument('-db',
                             '--dark_bg',
                             action='store_true',
                             help='Use dark background for plotting')
    plot_parser.add_argument('-pd',
                             '--print_decompositions',
                             action='store_true',
                             help='Print decompositions on plot')

    # Backend related arguments
    plot_parser.add_argument('-b',
                             '--backends',
                             nargs='*',
                             default=['MPI', 'NCCL', 'MPI4JAX'],
                             help='List of backends to include')

    # Scaling type argument
    plot_parser.add_argument('-sc',
                             '--scaling',
                             choices=['Weak', 'Strong'],
                             required=True,
                             help='Scaling type (Weak or Strong)')

    # Label customization argument.
    # argparse applies %-formatting to help strings, so the literal placeholder
    # delimiters must be written as '%%'; unescaped, ``plot --help`` fails
    # while formatting this text.
    # NOTE(review): the default label uses %f%, which the help text does not
    # list -- presumably the function name; confirm against the label code.
    plot_parser.add_argument(
        '-l',
        '--label_text',
        type=str,
        help=('Custom label for the plot. You can use placeholders: '
              '%%decomposition%% (or %%p%%), %%precision%% (or %%pr%%), '
              '%%plot_name%% (or %%pn%%), %%backend%% (or %%b%%), '
              '%%node%% (or %%n%%), %%methodname%% (or %%m%%)'),
        default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")

    args = parser.parse_args(argv)

    # Everything below only applies to ``plot``: a ``concat`` namespace has no
    # pdim_strategy / plot_times / plot_memory attributes.
    if args.command != 'plot':
        return args

    if not args.pdim_strategy:
        # ``-ps`` given without values yields []; fall back to the default.
        args.pdim_strategy = ['plot_fastest']

    if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
        print(
            "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
        )
        args.pdim_strategy = ['plot_all']

    if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
        print(
            "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
        )
        args.pdim_strategy = ['plot_fastest']

    if args.plot_times is not None:
        args.plot_columns = args.plot_times
    elif args.plot_memory is not None:
        args.plot_columns = args.plot_memory
    else:
        # Defensive only: the mutually exclusive group is required.
        raise ValueError('Either plot_times or plot_memory should be provided')

    if not args.plot_columns:
        # ``-pt``/``-pm`` without values would later fail on plot_columns[0].
        parser.error('--plot_times/--plot_memory require at least one column')

    return args
|
jax_hpc_profiler/main.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
from .create_argparse import create_argparser
|
|
4
|
+
from .plotting import plot_strong_scaling, plot_weak_scaling
|
|
5
|
+
from .utils import clean_up_csv, concatenate_csvs
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def main():
    """
    Command-line entry point of the package.

    * ``concat``: concatenate the CSV files of ``args.input`` into
      ``args.output``.
    * ``plot``: load and filter the CSV files with ``clean_up_csv``, then draw
      a weak- or strong-scaling figure. Exits with status 1 when the filters
      leave nothing to plot.
    """
    args = create_argparser()

    if args.command == 'concat':
        input_dir, output_dir = args.input, args.output
        concatenate_csvs(input_dir, output_dir)
    elif args.command == 'plot':
        dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
            args.csv_files, args.precision, args.function_name, args.gpus,
            args.data_size, args.filter_pdims, args.pdim_strategy,
            args.backends, args.memory_units)
        if len(dataframes) == 0:
            print("No dataframes found for the given arguments. Exiting...")
            sys.exit(1)
        print(
            f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
        )
        # -g/--gpus and -d/--data_size are optional (nargs='*' with no
        # default), so they are None when omitted and cannot be iterated.
        # When a filter was not given, keep every value present in the data;
        # otherwise keep only the requested values that are actually available.
        # NOTE(review): assumes the available_* collections hold sortable
        # scalars -- confirm against clean_up_csv.
        if args.gpus is None:
            args.gpus = sorted(available_gpu_counts)
        else:
            args.gpus = [
                gpu for gpu in args.gpus if gpu in available_gpu_counts
            ]
        if args.data_size is None:
            args.data_size = sorted(available_data_sizes)
        else:
            args.data_size = [
                data_size for data_size in args.data_size
                if data_size in available_data_sizes
            ]
        if len(args.gpus) == 0:
            print("No dataframes found for the given GPUs. Exiting...")
            sys.exit(1)
        if len(args.data_size) == 0:
            print("No dataframes found for the given data sizes. Exiting...")
            sys.exit(1)

        if args.scaling == 'Weak':
            plot_weak_scaling(dataframes, args.gpus, args.figure_size,
                              args.output, args.dark_bg,
                              args.print_decompositions, args.backends,
                              args.precision, args.function_name,
                              args.plot_columns, args.memory_units,
                              args.label_text, args.pdim_strategy)
        elif args.scaling == 'Strong':
            plot_strong_scaling(dataframes, args.data_size, args.figure_size,
                                args.output, args.dark_bg,
                                args.print_decompositions, args.backends,
                                args.precision, args.function_name,
                                args.plot_columns, args.memory_units,
                                args.label_text, args.pdim_strategy)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
from itertools import product
|
|
2
|
+
from typing import Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from matplotlib.axes import Axes
|
|
8
|
+
from matplotlib.patches import FancyBboxPatch
|
|
9
|
+
|
|
10
|
+
from .utils import inspect_df, plot_with_pdims_strategy
|
|
11
|
+
|
|
12
|
+
np.seterr(divide='ignore')
|
|
13
|
+
plt.rcParams.update({'font.size': 10})
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def configure_axes(ax: Axes,
                   x_values: List[int],
                   y_values: List[float],
                   xlabel: str,
                   title: str,
                   plotting_memory: bool = False,
                   memory_units: str = 'bytes'):
    """
    Configure limits, scales, ticks, labels and legend of one subplot.

    Parameters
    ----------
    ax : Axes
        The axes to configure.
    x_values : List[int]
        The plotted x-axis values; used for the x-limits, the x-ticks and the
        dashed vertical guide lines.
    y_values : List[float]
        The plotted y-axis values; used for the y-limits (10% margin on each
        side) and, for timings, the decade y-ticks.
    xlabel : str
        The label for the x-axis.
    title : str
        The subplot title.
    plotting_memory : bool, optional
        If True the y-axis is labelled as memory in ``memory_units`` and keeps
        a linear scale; otherwise it is labelled as time in milliseconds and
        uses a symlog scale.
    memory_units : str, optional
        Unit name shown in the y-axis label when plotting memory.
    """
    ylabel = 'Time (milliseconds)' if not plotting_memory else f'Memory ({memory_units})'
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Nothing was plotted on this subplot (everything was filtered out):
    # min()/max() below would raise on empty sequences.
    if len(x_values) == 0 or len(y_values) == 0:
        return

    f2 = lambda x: np.log2(x)
    g2 = lambda x: 2**x
    ax.set_xlim([min(x_values), max(x_values)])
    y_min, y_max = min(y_values) * 0.9, max(y_values) * 1.1
    ax.set_ylim([y_min, y_max])
    # Base-2 function scale on the x-axis.
    ax.set_xscale('function', functions=(f2, g2))
    if not plotting_memory:
        ax.set_yscale('symlog')
        # One tick per decade spanning the data. Only defined for a strictly
        # positive range: log10 of a non-positive value is -inf/nan and
        # int() of it would raise.
        if y_min > 0:
            low = int(np.floor(np.log10(y_min)))
            high = int(np.ceil(np.log10(y_max)))
            ax.set_yticks([10**t for t in range(low, high + 1)])
    ax.set_xticks(x_values)
    for x_value in x_values:
        ax.axvline(x=x_value, color='gray', linestyle='--', alpha=0.5)
    ax.legend(loc='lower center',
              bbox_to_anchor=(0.5, 0.05),
              ncol=4,
              prop={'size': 14})
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def plot_scaling(dataframes: Dict[str, pd.DataFrame],
                 fixed_sizes: List[int],
                 size_column: str,
                 fixed_column: str,
                 xlabel: str,
                 title: str,
                 figure_size: tuple = (6, 4),
                 output: Optional[str] = None,
                 dark_bg: bool = False,
                 print_decompositions: bool = False,
                 backends: List[str] = ['NCCL'],
                 precisions: List[str] = ['float32'],
                 functions: List[str] | None = None,
                 plot_columns: List[str] = ['mean_time'],
                 memory_units: str = 'bytes',
                 label_text: str = 'plot',
                 pdims_strategy: str = 'plot_fastest'):
    """
    General scaling plot function based on the number of GPUs or data size.

    One subplot is drawn per entry of ``fixed_sizes``. For each subplot, every
    dataframe is restricted to ``fixed_column == fixed_size`` and to each
    (backend, precision, function, plot column) combination. The matching rows
    are drawn against ``size_column`` by ``plot_with_pdims_strategy``.

    Parameters
    ----------
    dataframes : Dict[str, pd.DataFrame]
        Dictionary of method names to dataframes.
    fixed_sizes : List[int]
        List of fixed sizes (data or GPUs) to plot; one subplot each.
    size_column : str
        Column name for the size axis ('x' for weak scaling, 'gpus' for strong scaling).
    fixed_column : str
        Column name for the fixed axis ('gpus' for weak scaling, 'x' for strong scaling).
    xlabel : str
        Label for the x-axis.
    title : str
        Subplot title prefix; the fixed size is appended to it.
    figure_size : tuple, optional
        Size of the figure, by default (6, 4).
    output : Optional[str], optional
        Output file to save the plot, by default None (show interactively).
    dark_bg : bool, optional
        Whether to use dark background for the plot, by default False.
    print_decompositions : bool, optional
        Whether to print decompositions on the plot, by default False.
    backends : List[str], optional
        Backends to include, by default ['NCCL'].
    precisions : List[str], optional
        Precisions to include, by default ['float32'].
    functions : List[str] | None, optional
        Functions to include. None means every function present in each
        method's filtered dataframe.
    plot_columns : List[str], optional
        Columns to plot; a first column without 'time' in its name switches
        the y-axis to memory mode. By default ['mean_time'].
    memory_units : str, optional
        Unit name shown on the y-axis in memory mode.
    label_text : str, optional
        Label template forwarded to ``plot_with_pdims_strategy``.
    pdims_strategy : str, optional
        Strategy for plotting pdims, forwarded to ``plot_with_pdims_strategy``,
        by default 'plot_fastest'.

    Raises
    ------
    ValueError
        If ``fixed_sizes`` is empty.
    """
    if len(fixed_sizes) == 0:
        raise ValueError('fixed_sizes must contain at least one size to plot')

    if dark_bg:
        # plt.style.use updates the global rcParams, so this persists for
        # later figures in the same process.
        plt.style.use('dark_background')

    num_subplots = len(fixed_sizes)
    num_rows = int(np.ceil(np.sqrt(num_subplots)))
    num_cols = int(np.ceil(num_subplots / num_rows))

    # squeeze=False always returns a 2-D array of Axes, so one or many
    # subplots are handled uniformly.
    fig, axs = plt.subplots(num_rows,
                            num_cols,
                            figsize=figure_size,
                            squeeze=False)
    axs = axs.flatten()

    # Loop invariant: columns without 'time' in the name are memory columns.
    plotting_memory = 'time' not in plot_columns[0].lower()

    for ax, fixed_size in zip(axs, fixed_sizes):
        subplot_title = f"{title} {fixed_size}"
        # Collected across all methods so the axis limits cover every curve
        # drawn on this subplot, not just the last method's.
        x_values = []
        y_values = []

        for method, df in dataframes.items():
            filtered_method_df = df[df[fixed_column] == int(fixed_size)]
            if filtered_method_df.empty:
                continue
            filtered_method_df = filtered_method_df.sort_values(
                by=[size_column])
            # Use a local name: rebinding ``functions`` here would make every
            # later method reuse the first method's function list.
            method_functions = (pd.unique(filtered_method_df['function'])
                                if functions is None else functions)
            combinations = product(backends, precisions, method_functions,
                                   plot_columns)

            for backend, precision, function, plot_column in combinations:
                filtered_params_df = filtered_method_df[
                    (filtered_method_df['backend'] == backend)
                    & (filtered_method_df['precision'] == precision) &
                    (filtered_method_df['function'] == function)]
                if filtered_params_df.empty:
                    continue
                x_vals, y_vals = plot_with_pdims_strategy(
                    ax, filtered_params_df, method, pdims_strategy,
                    print_decompositions, size_column, plot_column, label_text)

                x_values.extend(x_vals)
                y_values.extend(y_vals)

        if len(x_values) == 0 or len(y_values) == 0:
            # Nothing matched the filters for this subplot: label it only.
            ax.set_title(subplot_title)
            ax.set_xlabel(xlabel)
            continue

        # Several methods share the same x positions; de-duplicate so each
        # tick / guide line is drawn once.
        configure_axes(ax,
                       sorted(set(x_values)),
                       y_values,
                       xlabel=xlabel,
                       title=subplot_title,
                       plotting_memory=plotting_memory,
                       memory_units=memory_units)

    # Remove the unused cells of the subplot grid.
    for unused_ax in axs[num_subplots:]:
        fig.delaxes(unused_ax)

    fig.tight_layout()
    # NOTE(review): no ``transform`` is set and the patch is appended to
    # ``fig.patches`` directly, so matplotlib appears to interpret these
    # coordinates in display pixels rather than figure fractions, which makes
    # the frame practically invisible. Pass ``transform=fig.transFigure`` (or
    # use ``fig.add_artist``) if a visible figure-fraction frame was intended.
    rect = FancyBboxPatch((0.1, 0.1),
                          0.8,
                          0.8,
                          boxstyle="round,pad=0.02",
                          ec="black",
                          fc="none")
    fig.patches.append(rect)
    if output is None:
        plt.show()
    else:
        fig.savefig(output, bbox_inches='tight', transparent=False)
        # The figure is not returned to the caller, so release it.
        plt.close(fig)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def plot_strong_scaling(dataframes: Dict[str, pd.DataFrame],
                        fixed_data_size: List[int],
                        figure_size: tuple = (6, 4),
                        output: Optional[str] = None,
                        dark_bg: bool = False,
                        print_decompositions: bool = False,
                        backends: List[str] = ['NCCL'],
                        precisions: List[str] = ['float32'],
                        functions: List[str] | None = None,
                        plot_columns: List[str] = ['mean_time'],
                        memory_units: str = 'bytes',
                        label_text: str = 'plot',
                        pdims_strategy: str = 'plot_fastest'):
    """
    Plot strong scaling: the measured column versus the number of GPUs.

    One subplot is drawn per data size in ``fixed_data_size``. Rows are
    selected on the ``x`` column and drawn against the ``gpus`` column. Every
    other argument is forwarded unchanged to ``plot_scaling``.
    """
    plot_scaling(dataframes,
                 fixed_data_size,
                 size_column='gpus',
                 fixed_column='x',
                 xlabel='Number of GPUs',
                 title='Data size',
                 figure_size=figure_size,
                 output=output,
                 dark_bg=dark_bg,
                 print_decompositions=print_decompositions,
                 backends=backends,
                 precisions=precisions,
                 functions=functions,
                 plot_columns=plot_columns,
                 memory_units=memory_units,
                 label_text=label_text,
                 pdims_strategy=pdims_strategy)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def plot_weak_scaling(dataframes: Dict[str, pd.DataFrame],
                      fixed_gpu_size: List[int],
                      figure_size: tuple = (6, 4),
                      output: Optional[str] = None,
                      dark_bg: bool = False,
                      print_decompositions: bool = False,
                      backends: List[str] = ['NCCL'],
                      precisions: List[str] = ['float32'],
                      functions: List[str] | None = None,
                      plot_columns: List[str] = ['mean_time'],
                      memory_units: str = 'bytes',
                      label_text: str = 'plot',
                      pdims_strategy: str = 'plot_fastest'):
    """
    Plot weak scaling: the measured column versus the data size.

    One subplot is drawn per GPU count in ``fixed_gpu_size``. Rows are
    selected on the ``gpus`` column and drawn against the ``x`` column. Every
    other argument is forwarded unchanged to ``plot_scaling``.
    """
    plot_scaling(dataframes,
                 fixed_gpu_size,
                 size_column='x',
                 fixed_column='gpus',
                 xlabel='Data size',
                 title='Number of GPUs',
                 figure_size=figure_size,
                 output=output,
                 dark_bg=dark_bg,
                 print_decompositions=print_decompositions,
                 backends=backends,
                 precisions=precisions,
                 functions=functions,
                 plot_columns=plot_columns,
                 memory_units=memory_units,
                 label_text=label_text,
                 pdims_strategy=pdims_strategy)
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import Any, Callable, List
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import numpy as np
|
|
9
|
+
from jax import make_jaxpr
|
|
10
|
+
from jax.experimental.shard_map import shard_map
|
|
11
|
+
from jax.sharding import Mesh, NamedSharding
|
|
12
|
+
from jax.sharding import PartitionSpec as P
|
|
13
|
+
from tabulate import tabulate
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Timer:
|
|
17
|
+
|
|
18
|
+
def __init__(self, save_jaxpr=False):
|
|
19
|
+
self.jit_time = None
|
|
20
|
+
self.times = []
|
|
21
|
+
self.profiling_data = {}
|
|
22
|
+
self.compiled_code = {}
|
|
23
|
+
self.save_jaxpr = save_jaxpr
|
|
24
|
+
|
|
25
|
+
def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
26
|
+
start = time.perf_counter()
|
|
27
|
+
out = jax.jit(fun)(*args)
|
|
28
|
+
if ndarray_arg is None:
|
|
29
|
+
out.block_until_ready()
|
|
30
|
+
else:
|
|
31
|
+
out[ndarray_arg].block_until_ready()
|
|
32
|
+
end = time.perf_counter()
|
|
33
|
+
self.jit_time = (end - start) * 1e3
|
|
34
|
+
|
|
35
|
+
if self.save_jaxpr:
|
|
36
|
+
jaxpr = make_jaxpr(fun)(*args)
|
|
37
|
+
self.compiled_code["JAXPR"] = jaxpr.pretty_print()
|
|
38
|
+
|
|
39
|
+
lowered = jax.jit(fun).lower(*args)
|
|
40
|
+
compiled = lowered.compile()
|
|
41
|
+
memory_analysis = compiled.memory_analysis()
|
|
42
|
+
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
43
|
+
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
44
|
+
self.profiling_data["FLOPS"] = compiled.cost_analysis()[0]['flops']
|
|
45
|
+
self.profiling_data[
|
|
46
|
+
"generated_code"] = memory_analysis.generated_code_size_in_bytes
|
|
47
|
+
self.profiling_data[
|
|
48
|
+
"argument_size"] = memory_analysis.argument_size_in_bytes
|
|
49
|
+
self.profiling_data[
|
|
50
|
+
"output_size"] = memory_analysis.output_size_in_bytes
|
|
51
|
+
self.profiling_data["temp_size"] = memory_analysis.temp_size_in_bytes
|
|
52
|
+
|
|
53
|
+
return out
|
|
54
|
+
|
|
55
|
+
def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
56
|
+
start = time.perf_counter()
|
|
57
|
+
out = fun(*args)
|
|
58
|
+
if ndarray_arg is None:
|
|
59
|
+
out.block_until_ready()
|
|
60
|
+
else:
|
|
61
|
+
out[ndarray_arg].block_until_ready()
|
|
62
|
+
end = time.perf_counter()
|
|
63
|
+
self.times.append((end - start) * 1e3)
|
|
64
|
+
return out
|
|
65
|
+
|
|
66
|
+
def _get_mean_times(self, times_array: jnp.ndarray,
|
|
67
|
+
sharding: NamedSharding):
|
|
68
|
+
mesh = sharding.mesh
|
|
69
|
+
specs = sharding.spec
|
|
70
|
+
valid_letters = [letter for letter in specs if letter is not None]
|
|
71
|
+
assert len(valid_letters
|
|
72
|
+
) > 0, "Sharding was provided but with no partition specs"
|
|
73
|
+
|
|
74
|
+
@partial(shard_map,
|
|
75
|
+
mesh=mesh,
|
|
76
|
+
in_specs=specs,
|
|
77
|
+
out_specs=P(),
|
|
78
|
+
check_rep=False)
|
|
79
|
+
def get_mean_times(times):
|
|
80
|
+
mean = jax.lax.pmean(times, axis_name=valid_letters[0])
|
|
81
|
+
for axis_name in valid_letters[1:]:
|
|
82
|
+
mean = jax.lax.pmean(mean, axis_name=axis_name)
|
|
83
|
+
return mean
|
|
84
|
+
|
|
85
|
+
times_array = get_mean_times(times_array)
|
|
86
|
+
times_array.block_until_ready()
|
|
87
|
+
return times_array
|
|
88
|
+
|
|
89
|
+
def report(self,
|
|
90
|
+
csv_filename: str,
|
|
91
|
+
function: str,
|
|
92
|
+
precision: str,
|
|
93
|
+
x: int,
|
|
94
|
+
y: int,
|
|
95
|
+
z: int,
|
|
96
|
+
px: int,
|
|
97
|
+
py: int,
|
|
98
|
+
backend: str,
|
|
99
|
+
nodes: int,
|
|
100
|
+
sharding: NamedSharding | None = None,
|
|
101
|
+
md_filename: str | None = None,
|
|
102
|
+
extra_info: dict = {}):
|
|
103
|
+
times_array = jnp.array(self.times)
|
|
104
|
+
|
|
105
|
+
if md_filename is None:
|
|
106
|
+
dirname, filename = os.path.dirname(csv_filename), os.path.splitext(os.path.basename(csv_filename))[0]
|
|
107
|
+
report_folder = filename if dirname == "" else f"{dirname}/{filename}"
|
|
108
|
+
print(f"report_folder: {report_folder} csv_filename: {csv_filename}")
|
|
109
|
+
os.makedirs(report_folder, exist_ok=True)
|
|
110
|
+
md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
|
|
111
|
+
|
|
112
|
+
if sharding is not None:
|
|
113
|
+
times_array = self._get_mean_times(times_array, sharding)
|
|
114
|
+
|
|
115
|
+
times_array = np.array(times_array)
|
|
116
|
+
min_time = np.min(times_array)
|
|
117
|
+
max_time = np.max(times_array)
|
|
118
|
+
mean_time = np.mean(times_array)
|
|
119
|
+
std_time = np.std(times_array)
|
|
120
|
+
last_time = times_array[-1]
|
|
121
|
+
|
|
122
|
+
flops = self.profiling_data["FLOPS"]
|
|
123
|
+
generated_code = self.profiling_data["generated_code"]
|
|
124
|
+
argument_size = self.profiling_data["argument_size"]
|
|
125
|
+
output_size = self.profiling_data["output_size"]
|
|
126
|
+
temp_size = self.profiling_data["temp_size"]
|
|
127
|
+
|
|
128
|
+
csv_line = (
|
|
129
|
+
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
130
|
+
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
131
|
+
f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
with open(csv_filename, 'a') as f:
|
|
135
|
+
f.write(csv_line)
|
|
136
|
+
|
|
137
|
+
param_dict = {
|
|
138
|
+
"Function": function,
|
|
139
|
+
"Precision": precision,
|
|
140
|
+
"X": x,
|
|
141
|
+
"Y": y,
|
|
142
|
+
"Z": z,
|
|
143
|
+
"PX": px,
|
|
144
|
+
"PY": py,
|
|
145
|
+
"Backend": backend,
|
|
146
|
+
"Nodes": nodes,
|
|
147
|
+
}
|
|
148
|
+
param_dict.update(extra_info)
|
|
149
|
+
profiling_result = {
|
|
150
|
+
"JIT Time": self.jit_time,
|
|
151
|
+
"Min Time": min_time,
|
|
152
|
+
"Max Time": max_time,
|
|
153
|
+
"Mean Time": mean_time,
|
|
154
|
+
"Std Time": std_time,
|
|
155
|
+
"Last Time": last_time,
|
|
156
|
+
"Generated Code": generated_code,
|
|
157
|
+
"Argument Size": argument_size,
|
|
158
|
+
"Output Size": output_size,
|
|
159
|
+
"Temporary Size": temp_size,
|
|
160
|
+
"FLOPS": self.profiling_data["FLOPS"]
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
with open(md_filename, 'w') as f:
|
|
164
|
+
f.write(f"# Reporting for {function}\n")
|
|
165
|
+
f.write(f"## Parameters\n")
|
|
166
|
+
f.write(tabulate(param_dict.items() , headers=["Parameter" , "Value"] , tablefmt='github'))
|
|
167
|
+
f.write("\n---\n")
|
|
168
|
+
f.write(f"## Profiling Data\n")
|
|
169
|
+
f.write(tabulate(profiling_result.items() , headers=["Parameter" , "Value"] , tablefmt='github'))
|
|
170
|
+
f.write("\n---\n")
|
|
171
|
+
f.write(f"## Compiled Code\n")
|
|
172
|
+
f.write(f"```hlo\n")
|
|
173
|
+
f.write(self.compiled_code["COMPILED"])
|
|
174
|
+
f.write(f"\n```\n")
|
|
175
|
+
f.write("\n---\n")
|
|
176
|
+
f.write(f"## Lowered Code\n")
|
|
177
|
+
f.write(f"```hlo\n")
|
|
178
|
+
f.write(self.compiled_code["LOWERED"])
|
|
179
|
+
f.write(f"\n```\n")
|
|
180
|
+
f.write("\n---\n")
|
|
181
|
+
if self.save_jaxpr:
|
|
182
|
+
f.write(f"## JAXPR\n")
|
|
183
|
+
f.write(f"```haskel\n")
|
|
184
|
+
f.write(self.compiled_code["JAXPR"])
|
|
185
|
+
f.write(f"\n```\n")
|