jax-hpc-profiler 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/create_argparse.py +147 -111
- jax_hpc_profiler/main.py +47 -45
- jax_hpc_profiler/plotting.py +183 -90
- jax_hpc_profiler/timer.py +122 -82
- jax_hpc_profiler/utils.py +15 -2
- {jax_hpc_profiler-0.2.7.dist-info → jax_hpc_profiler-0.2.9.dist-info}/METADATA +2 -2
- jax_hpc_profiler-0.2.9.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.7.dist-info → jax_hpc_profiler-0.2.9.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.7.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.7.dist-info → jax_hpc_profiler-0.2.9.dist-info}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.7.dist-info → jax_hpc_profiler-0.2.9.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.7.dist-info → jax_hpc_profiler-0.2.9.dist-info}/top_level.txt +0 -0

jax_hpc_profiler/create_argparse.py
CHANGED

The rewritten body of create_argparser() in 0.2.9:

@@ -11,151 +11,187 @@ def create_argparser():

        Parsed and validated arguments.
    """
    parser = argparse.ArgumentParser(
        description="HPC Plotter for benchmarking data")

    # Group for concatenation to ensure mutually exclusive behavior
    subparsers = parser.add_subparsers(dest="command", required=True)

    concat_parser = subparsers.add_parser("concat", help="Concatenate CSV files")
    concat_parser.add_argument("input", type=str,
                               help="Input directory for concatenation")
    concat_parser.add_argument("output", type=str,
                               help="Output directory for concatenation")

    # Arguments for plotting
    plot_parser = subparsers.add_parser("plot", help="Plot CSV data")
    plot_parser.add_argument("-f", "--csv_files", nargs="+",
                             help="List of CSV files to plot", required=True)
    plot_parser.add_argument("-g", "--gpus", nargs="*", type=int,
                             help="List of number of GPUs to plot", default=None)
    plot_parser.add_argument("-d", "--data_size", nargs="*", type=int,
                             help="List of data sizes to plot", default=None)

    # pdims related arguments
    plot_parser.add_argument("-fd", "--filter_pdims", nargs="*",
                             help="List of pdims to filter, e.g., 1x4 2x2 4x8",
                             default=None)
    plot_parser.add_argument("-ps", "--pdim_strategy",
                             choices=["plot_all", "plot_fastest", "slab_yz",
                                      "slab_xy", "pencils"],
                             nargs="*", default=["plot_fastest"],
                             help="Strategy for plotting pdims")

    # Function and precision related arguments
    plot_parser.add_argument("-pr", "--precision",
                             choices=["float32", "float64"],
                             default=["float32", "float64"], nargs="*",
                             help="Precision to filter by (float32 or float64)")
    plot_parser.add_argument("-fn", "--function_name", nargs="+",
                             help="Function names to filter", default=None)

    # Time or memory related arguments
    plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
    plotting_group.add_argument("-pt", "--plot_times", nargs="*",
                                choices=["jit_time", "min_time", "max_time",
                                         "mean_time", "std_time", "last_time"],
                                help="Time columns to plot")
    plotting_group.add_argument("-pm", "--plot_memory", nargs="*",
                                choices=["generated_code", "argument_size",
                                         "output_size", "temp_size"],
                                help="Memory columns to plot")
    plot_parser.add_argument("-mu", "--memory_units", default="GB",
                             help="Memory units to plot (KB, MB, GB, TB)")

    # Plot customization arguments
    plot_parser.add_argument("-fs", "--figure_size", nargs=2, type=int,
                             help="Figure size", default=(10, 6))
    plot_parser.add_argument("-o", "--output",
                             help="Output file (if none then only show plot)",
                             default=None)
    plot_parser.add_argument("-db", "--dark_bg", action="store_true",
                             help="Use dark background for plotting")
    plot_parser.add_argument("-pd", "--print_decompositions",
                             action="store_true",
                             help="Print decompositions on plot")

    # Backend related arguments
    plot_parser.add_argument("-b", "--backends", nargs="*",
                             default=["MPI", "NCCL", "MPI4JAX"],
                             help="List of backends to include")

    # Scaling type argument
    plot_parser.add_argument("-sc", "--scaling",
                             choices=["Weak", "Strong", "w", "s"],
                             required=True,
                             help="Scaling type (Weak or Strong)")

    # Label customization argument
    plot_parser.add_argument(
        "-l", "--label_text", type=str,
        help=("Custom label for the plot. You can use placeholders: %%decomposition%% "
              "(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), "
              "%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)"),
        default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")

    subparsers.add_parser("label_help", help="Label customization help")

    args = parser.parse_args()

    # if command was plot, then check if pdim_strategy is valid
    if args.command == "plot":
        if "plot_all" in args.pdim_strategy and len(args.pdim_strategy) > 1:
            print(
                "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
            )
            args.pdim_strategy = ["plot_all"]

        if "plot_fastest" in args.pdim_strategy and len(args.pdim_strategy) > 1:
            print(
                "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
            )
            args.pdim_strategy = ["plot_fastest"]
        if args.plot_times is not None:
            args.plot_columns = args.plot_times
        elif args.plot_memory is not None:
            args.plot_columns = args.plot_memory
        else:
            raise ValueError(
                "Either plot_times or plot_memory should be provided")

    return args
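
A minimal sketch of exercising the reworked parser from Python. The program name in sys.argv and the CSV path are placeholders; the flags and the derived plot_columns attribute come from the parser and the validation block above:

import sys

from jax_hpc_profiler.create_argparse import create_argparser

# Hypothetical command line; argv[0] and the CSV file name are placeholders.
sys.argv = [
    "jax-hpc-profiler", "plot",
    "-f", "results.csv",
    "-pt", "mean_time", "std_time",
    "-ps", "plot_fastest",
    "-sc", "Weak",
]
args = create_argparser()
print(args.plot_columns)  # ['mean_time', 'std_time'], set by the validation block
print(args.scaling)       # 'Weak'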
jax_hpc_profiler/main.py
CHANGED

main() in 0.2.9:

@@ -1,4 +1,5 @@

import sys
from typing import List, Optional

from .create_argparse import create_argparser
from .plotting import plot_strong_scaling, plot_weak_scaling

@@ -8,55 +9,56 @@ from .utils import clean_up_csv, concatenate_csvs

def main():
    args = create_argparser()

    if args.command == "concat":
        input_dir, output_dir = args.input, args.output
        concatenate_csvs(input_dir, output_dir)
    elif args.command == "label_help":
        print(f"Customize the label text for the plot. using these commands.")
        print(" -- %m% or %methodname%: method name")
        print(" -- %f% or %function%: function name")
        print(" -- %pn% or %plot_name%: plot name")
        print(" -- %pr% or %precision%: precision")
        print(" -- %b% or %backend%: backend")
        print(" -- %p% or %pdims%: pdims")
        print(" -- %n% or %node%: node")
    elif args.command == "plot":

        if args.scaling.lower() == "weak" or args.scaling.lower() == "w":
            plot_weak_scaling(
                args.csv_files, args.gpus, args.data_size, args.function_name,
                args.precision, args.filter_pdims, args.pdim_strategy,
                args.print_decompositions, args.backends, args.plot_columns,
                args.memory_units, args.label_text, args.figure_size,
                args.dark_bg, args.output)
        elif args.scaling.lower() == "strong" or args.scaling.lower() == "s":
            plot_strong_scaling(
                args.csv_files, args.gpus, args.data_size, args.function_name,
                args.precision, args.filter_pdims, args.pdim_strategy,
                args.print_decompositions, args.backends, args.plot_columns,
                args.memory_units, args.label_text, args.figure_size,
                args.dark_bg, args.output)


if __name__ == "__main__":
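
A hedged sketch of driving the new subcommand dispatch programmatically. The program name and the directory names are placeholders; the subcommand names and the concatenate_csvs(input_dir, output_dir) forwarding are taken from main() above:

import sys

from jax_hpc_profiler.main import main

# Print the placeholder table from the new "label_help" branch above.
sys.argv = ["jax-hpc-profiler", "label_help"]
main()

# Merge per-run CSVs; "concat" forwards the two positional directories
# (hypothetical names here) to concatenate_csvs().
sys.argv = ["jax-hpc-profiler", "concat", "raw_runs", "merged"]
main()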
jax_hpc_profiler/plotting.py
CHANGED

The changed regions in 0.2.9:

@@ -4,22 +4,24 @@ from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.patches import FancyBboxPatch

from .utils import clean_up_csv, inspect_df, plot_with_pdims_strategy

np.seterr(divide="ignore")

(The module-level plt.rcParams.update({'font.size': 15}) call was removed.)

def configure_axes(
    ax: Axes,
    x_values: List[int],
    y_values: List[float],
    xlabel: str,
    title: str,
    plotting_memory: bool = False,
    memory_units: str = "bytes",
):
    """
    Configure the axes for the plot.

@@ -34,16 +36,17 @@ def configure_axes(ax: Axes,

    xlabel : str
        The label for the x-axis.
    """
    ylabel = ("Time (milliseconds)"
              if not plotting_memory else f"Memory ({memory_units})")
    f2 = lambda x: np.log2(x)
    g2 = lambda x: 2**x
    ax.set_xlim([min(x_values), max(x_values)])
    y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
    ax.set_title(title)
    ax.set_ylim([y_min, y_max])
    ax.set_xscale("function", functions=(f2, g2))
    if not plotting_memory:
        ax.set_yscale("symlog")
        time_ticks = [
            10**t for t in range(int(np.floor(np.log10(y_min))), 1 +
                                 int(np.ceil(np.log10(y_max))))

@@ -53,31 +56,35 @@ def configure_axes(

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    for x_value in x_values:
        ax.axvline(x=x_value, color="gray", linestyle="--", alpha=0.5)
    ax.legend(loc="lower center",
              bbox_to_anchor=(0.5, 0.05),
              ncol=4,
              fontsize="x-large",
              prop={"size": 14})


def plot_scaling(
    dataframes: Dict[str, pd.DataFrame],
    fixed_sizes: List[int],
    size_column: str,
    fixed_column: str,
    xlabel: str,
    title: str,
    figure_size: tuple = (6, 4),
    output: Optional[str] = None,
    dark_bg: bool = False,
    print_decompositions: bool = False,
    backends: List[str] = ["NCCL"],
    precisions: List[str] = ["float32"],
    functions: List[str] | None = None,
    plot_columns: List[str] = ["mean_time"],
    memory_units: str = "bytes",
    label_text: str = "plot",
    pdims_strategy: List[str] = ["plot_fastest"],
):
    """
    General scaling plot function based on the number of GPUs or data size.

@@ -106,8 +113,9 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],

    pdims_strategy : str, optional
        Strategy for plotting pdims ('plot_all' or 'plot_fastest'), by default 'plot_fastest'.
    """

    if dark_bg:
        plt.style.use("dark_background")

    num_subplots = len(fixed_sizes)
    num_rows = int(np.ceil(np.sqrt(num_subplots)))

@@ -122,37 +130,53 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],

    for i, fixed_size in enumerate(fixed_sizes):
        ax: Axes = axs[i]

        x_values = []
        y_values = []
        for method, df in dataframes.items():

            filtered_method_df = df[df[fixed_column] == int(fixed_size)]
            if filtered_method_df.empty:
                continue
            filtered_method_df = filtered_method_df.sort_values(
                by=[size_column])
            functions = (pd.unique(filtered_method_df["function"])
                         if functions is None else functions)
            combinations = product(backends, precisions, functions,
                                   plot_columns)

            for backend, precision, function, plot_column in combinations:

                filtered_params_df = filtered_method_df[
                    (filtered_method_df["backend"] == backend)
                    & (filtered_method_df["precision"] == precision)
                    & (filtered_method_df["function"] == function)]
                if filtered_params_df.empty:
                    continue
                x_vals, y_vals = plot_with_pdims_strategy(
                    ax, filtered_params_df, method, pdims_strategy,
                    print_decompositions, size_column, plot_column, label_text)

                x_values.extend(x_vals)
                y_values.extend(y_vals)

        if len(x_values) != 0:
            plotting_memory = "time" not in plot_columns[0].lower()
            configure_axes(ax, x_values, y_values, f"{title} {fixed_size}",
                           xlabel, plotting_memory, memory_units)

    for i in range(num_subplots, num_rows * num_cols):
        fig.delaxes(axs[i])

@@ -168,48 +192,117 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],

    if output is None:
        plt.show()
    else:
        plt.savefig(output, bbox_inches="tight", transparent=True)


def plot_strong_scaling(
    csv_files: List[str],
    fixed_gpu_size: Optional[List[int]] = None,
    fixed_data_size: Optional[List[int]] = None,
    functions: List[str] | None = None,
    precisions: List[str] = ["float32"],
    pdims: Optional[List[str]] = None,
    pdims_strategy: List[str] = ["plot_fastest"],
    print_decompositions: bool = False,
    backends: List[str] = ["NCCL"],
    plot_columns: List[str] = ["mean_time"],
    memory_units: str = "bytes",
    label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
    figure_size: tuple = (6, 4),
    dark_bg: bool = False,
    output: Optional[str] = None,
):
    """
    Plot strong scaling based on the number of GPUs.
    """

    dataframes, _, available_data_sizes = clean_up_csv(
        csv_files, precisions, functions, fixed_gpu_size, fixed_data_size,
        pdims, pdims_strategy, backends, memory_units)
    if len(dataframes) == 0:
        print(f"No dataframes found for the given arguments. Exiting...")
        return

    plot_scaling(dataframes, available_data_sizes, "gpus", "x",
                 "Number of GPUs", "Data size", figure_size, output, dark_bg,
                 print_decompositions, backends, precisions, functions,
                 plot_columns, memory_units, label_text, pdims_strategy)


def plot_weak_scaling(
    csv_files: List[str],
    fixed_gpu_size: Optional[List[int]] = None,
    fixed_data_size: Optional[List[int]] = None,
    functions: List[str] | None = None,
    precisions: List[str] = ["float32"],
    pdims: Optional[List[str]] = None,
    pdims_strategy: List[str] = ["plot_fastest"],
    print_decompositions: bool = False,
    backends: List[str] = ["NCCL"],
    plot_columns: List[str] = ["mean_time"],
    memory_units: str = "bytes",
    label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
    figure_size: tuple = (6, 4),
    dark_bg: bool = False,
    output: Optional[str] = None,
):
    """
    Plot weak scaling based on the data size.
    """
    dataframes, available_gpu_counts, _ = clean_up_csv(
        csv_files, precisions, functions, fixed_gpu_size, fixed_data_size,
        pdims, pdims_strategy, backends, memory_units)
    if len(dataframes) == 0:
        print(f"No dataframes found for the given arguments. Exiting...")
        return

    plot_scaling(dataframes, available_gpu_counts, "x", "gpus", "Data size",
                 "Number of GPUs", figure_size, output, dark_bg,
                 print_decompositions, backends, precisions, functions,
                 plot_columns, memory_units, label_text, pdims_strategy)
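
The same figures can be produced without the CLI by calling the new top-level functions directly. A minimal sketch, assuming a CSV file named benchmark.csv produced by the profiler; every keyword below is a parameter of plot_weak_scaling as defined above, and clean_up_csv() is invoked internally to filter the data:

from jax_hpc_profiler.plotting import plot_weak_scaling

plot_weak_scaling(
    ["benchmark.csv"],              # csv_files (hypothetical file)
    backends=["NCCL"],
    precisions=["float32"],
    plot_columns=["mean_time"],
    pdims_strategy=["plot_fastest"],
    label_text="%m%-%b%-%pr%",
    figure_size=(10, 6),
    output="weak_scaling.png",      # None would call plt.show() instead
)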
jax_hpc_profiler/timer.py
CHANGED

The Timer class after the 0.2.9 changes:

@@ -16,36 +16,45 @@ from tabulate import tabulate

class Timer:

    def __init__(self, save_jaxpr=False, jax_fn=True, devices=None):
        self.jit_time = 0.0
        self.times = []
        self.profiling_data = {}
        self.compiled_code = {}
        self.save_jaxpr = save_jaxpr
        self.jax_fn = jax_fn
        self.devices = devices

    def _normalize_memory_units(self, memory_analysis) -> str:

        if not self.jax_fn:
            return memory_analysis

        sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
        factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
        factor = 0 if memory_analysis == 0 else int(
            np.log10(memory_analysis) // 3)

        return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"

    def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
        if memory_analysis is None:
            return None, None, None, None
        return (
            memory_analysis.generated_code_size_in_bytes,
            memory_analysis.argument_size_in_bytes,
            memory_analysis.output_size_in_bytes,
            memory_analysis.temp_size_in_bytes,
        )

    def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
        start = time.perf_counter()
        out = fun(*args)
        if self.jax_fn:
            if ndarray_arg is None:
                out.block_until_ready()
            else:
                out[ndarray_arg].block_until_ready()
        end = time.perf_counter()
        self.jit_time = (end - start) * 1e3

@@ -53,78 +62,90 @@ class Timer:

            jaxpr = make_jaxpr(fun)(*args)
            self.compiled_code["JAXPR"] = jaxpr.pretty_print()

        if self.jax_fn:
            lowered = jax.jit(fun).lower(*args)
            compiled = lowered.compile()
            memory_analysis = self._read_memory_analysis(
                compiled.memory_analysis())

            self.compiled_code["LOWERED"] = lowered.as_text()
            self.compiled_code["COMPILED"] = compiled.as_text()
            self.profiling_data["generated_code"] = memory_analysis[0]
            self.profiling_data["argument_size"] = memory_analysis[1]
            self.profiling_data["output_size"] = memory_analysis[2]
            self.profiling_data["temp_size"] = memory_analysis[3]

        return out

    def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
        start = time.perf_counter()
        out = fun(*args)
        if self.jax_fn:
            if ndarray_arg is None:
                out.block_until_ready()
            else:
                out[ndarray_arg].block_until_ready()
        end = time.perf_counter()
        self.times.append((end - start) * 1e3)
        return out

    def _get_mean_times(self) -> np.ndarray:
        if jax.device_count() == 1 or jax.process_count() == 1:
            return np.array(self.times)

        if self.devices is None:
            self.devices = jax.devices()

        mesh = jax.make_mesh((len(self.devices), ), ("x", ),
                             devices=self.devices)
        sharding = NamedSharding(mesh, P("x"))

        times_array = jnp.array(self.times)
        global_shape = (jax.device_count(), times_array.shape[0])
        global_times = jax.make_array_from_callback(
            shape=global_shape,
            sharding=sharding,
            data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
        )

        @partial(shard_map,
                 mesh=mesh,
                 in_specs=P("x"),
                 out_specs=P(),
                 check_rep=False)
        def get_mean_times(times):
            return jax.lax.pmean(times, axis_name="x")

        times_array = get_mean_times(global_times)
        times_array.block_until_ready()
        return np.array(times_array.addressable_data(0)[0])

    def report(
        self,
        csv_filename: str,
        function: str,
        x: int,
        y: int | None = None,
        z: int | None = None,
        precision: str = "float32",
        px: int = 1,
        py: int = 1,
        backend: str = "NCCL",
        nodes: int = 1,
        md_filename: str | None = None,
        extra_info: dict = {},
    ):

        if md_filename is None:
            dirname, filename = (
                os.path.dirname(csv_filename),
                os.path.splitext(os.path.basename(csv_filename))[0],
            )
            report_folder = filename if dirname == "" else f"{dirname}/{filename}"
            os.makedirs(report_folder, exist_ok=True)
            md_filename = (
                f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
            )

        y = x if y is None else y
        z = x if z is None else z

@@ -138,10 +159,17 @@ class Timer:

        std_time = np.std(times_array)
        last_time = times_array[-1]

        if self.jax_fn:

            generated_code = self.profiling_data["generated_code"]
            argument_size = self.profiling_data["argument_size"]
            output_size = self.profiling_data["output_size"]
            temp_size = self.profiling_data["temp_size"]
        else:
            generated_code = "N/A"
            argument_size = "N/A"
            output_size = "N/A"
            temp_size = "N/A"

        csv_line = (
            f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"

@@ -149,7 +177,7 @@ class Timer:

            f"{generated_code},{argument_size},{output_size},{temp_size}\n"
        )

        with open(csv_filename, "a") as f:
            f.write(csv_line)

        param_dict = {

@@ -175,44 +203,56 @@ class Timer:

            "Argument Size": self._normalize_memory_units(argument_size),
            "Output Size": self._normalize_memory_units(output_size),
            "Temporary Size": self._normalize_memory_units(temp_size),
        }

(The "FLOPS": self.profiling_data["FLOPS"] entry was removed from this table.)

        iteration_runs = {}
        for i in range(len(times_array)):
            iteration_runs[f"Run {i}"] = times_array[i]

        with open(md_filename, "w") as f:
            f.write(f"# Reporting for {function}\n")
            f.write(f"## Parameters\n")
            f.write(
                tabulate(param_dict.items(),
                         headers=["Parameter", "Value"],
                         tablefmt="github"))
            f.write("\n---\n")
            f.write(f"## Profiling Data\n")
            f.write(
                tabulate(profiling_result.items(),
                         headers=["Parameter", "Value"],
                         tablefmt="github"))
            f.write("\n---\n")
            f.write(f"## Iteration Runs\n")
            f.write(
                tabulate(iteration_runs.items(),
                         headers=["Iteration", "Time"],
                         tablefmt="github"))
            if self.jax_fn:
                f.write("\n---\n")
                f.write(f"## Compiled Code\n")
                f.write(f"```hlo\n")
                f.write(self.compiled_code["COMPILED"])
                f.write(f"\n```\n")
                f.write("\n---\n")
                f.write(f"## Lowered Code\n")
                f.write(f"```hlo\n")
                f.write(self.compiled_code["LOWERED"])
                f.write(f"\n```\n")
                f.write("\n---\n")
                if self.save_jaxpr:
                    f.write(f"## JAXPR\n")
                    f.write(f"```haskel\n")
                    f.write(self.compiled_code["JAXPR"])
                    f.write(f"\n```\n")

        # Reset the timer
        self.jit_time = 0.0
        self.times = []
        self.profiling_data = {}
        self.compiled_code = {}
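
A usage sketch for the extended Timer. With the default jax_fn=True the timer blocks on the result and captures the lowered/compiled code and memory analysis; jax_fn=False lets the same class time non-JAX callables and writes "N/A" for the memory columns. The benchmarked function, array shape and file paths below are hypothetical:

import jax
import jax.numpy as jnp

from jax_hpc_profiler.timer import Timer

@jax.jit
def fft3d(x):
    return jnp.fft.fftn(x).real

x = jnp.ones((256, 256, 256), dtype=jnp.float32)

timer = Timer()                       # save_jaxpr=False, jax_fn=True, devices=None
out = timer.chrono_jit(fft3d, x)      # first call: fills jit_time and profiling_data
for _ in range(10):
    out = timer.chrono_fun(fft3d, x)  # subsequent calls: appended to timer.times

# Appends one CSV row, writes the Markdown report next to it, then resets the timer.
timer.report("out/fft.csv", function="fft3d", x=256, precision="float32",
             px=1, py=1, backend="NCCL", nodes=1)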
jax_hpc_profiler/utils.py
CHANGED

The changed regions of clean_up_csv() in 0.2.9:

@@ -250,7 +250,7 @@ def clean_up_csv(

    pdims_strategy: List[str] = ['plot_fastest'],
    backends: List[str] = ['MPI', 'NCCL', 'MPI4JAX'],
    memory_units: str = 'KB',
) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
    """
    Clean up and aggregate data from CSV files.

@@ -341,7 +341,6 @@ def clean_up_csv(

        if pdims:
            px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
            df = df[(df['px'].isin(px_list)) & (df['py'].isin(py_list))]
        # convert memory units columns to requested memory_units
        match memory_units:
            case 'KB':

@@ -385,6 +384,7 @@ def clean_up_csv(

        df.drop(columns=['px', 'py'], inplace=True)
        if not 'plot_all' in pdims_strategy:
            df = df[df['decomp'].isin(pdims_strategy)]

        # check available gpus in dataset
        available_gpu_counts.update(df['gpus'].unique())
        available_data_sizes.update(df['x'].unique())

@@ -394,4 +394,17 @@ def clean_up_csv(

        else:
            dataframes[file_name] = pd.concat([dataframes[file_name], df])

    print(f"requested GPUS: {gpus} available GPUS: {available_gpu_counts}")
    print(
        f"requested data sizes: {data_sizes} available data sizes: {available_data_sizes}"
    )

    available_gpu_counts = (available_gpu_counts if gpus is None else
                            [gpu for gpu in gpus if gpu in available_gpu_counts])
    available_data_sizes = (available_data_sizes if data_sizes is None else
                            [data_size for data_size in data_sizes
                             if data_size in available_data_sizes])

    return dataframes, available_gpu_counts, available_data_sizes
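
clean_up_csv now also returns the GPU counts and data sizes that are actually present, intersected with the requested ones when given. A minimal sketch of consuming the new three-element return value; the CSV name is hypothetical and the positional argument order follows the calls in plotting.py above:

from jax_hpc_profiler.utils import clean_up_csv

dataframes, gpu_counts, data_sizes = clean_up_csv(
    ["benchmark.csv"],      # csv_files (hypothetical file)
    ["float32"],            # precisions
    None,                   # function name filter
    None,                   # requested GPU counts (None keeps all available)
    None,                   # requested data sizes (None keeps all available)
    None,                   # pdims filter, e.g. ["4x4"]
    ["plot_fastest"],       # pdims_strategy
    ["NCCL"],               # backends
    "GB",                   # memory_units
)
for name, df in dataframes.items():
    print(name, sorted(df["gpus"].unique()))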

{jax_hpc_profiler-0.2.7.dist-info → jax_hpc_profiler-0.2.9.dist-info}/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.1
 Name: jax_hpc_profiler
-Version: 0.2.7
+Version: 0.2.9
 Summary: HPC Plotter and profiler for benchmarking data made for JAX
 Author: Wassim Kabalan
-License: …
+License: GNU GENERAL PUBLIC LICENSE
         Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>

jax_hpc_profiler-0.2.9.dist-info/RECORD
ADDED

@@ -0,0 +1,12 @@
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
jax_hpc_profiler/create_argparse.py,sha256=CSdl76LvaTVVn43dkwpVyiIkyl4lHlDCiI5jvUrIoj0,6059
jax_hpc_profiler/main.py,sha256=2zPVTGRgFkYV75EJA1eoOqf92gCRXAtg-28cFgRy3Bw,2164
jax_hpc_profiler/plotting.py,sha256=8ELOB_Yv_AdSVWtS-jrRNm0HtK5FgKwf_ljeNRfdp14,9087
jax_hpc_profiler/timer.py,sha256=p7MUcbd2H4_tRAevhG9T4jJ8XL-liComvJn2sis4psM,9209
jax_hpc_profiler/utils.py,sha256=hSsS34i46WdCR9XRW1-02fI_k0RUty78imnI-xAc-tY,14644
jax_hpc_profiler-0.2.9.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
jax_hpc_profiler-0.2.9.dist-info/METADATA,sha256=CelrNVm13lK7L1ZkdOqD8Tm7qLBIF1oCyaghzDdrLRg,49270
jax_hpc_profiler-0.2.9.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
jax_hpc_profiler-0.2.9.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
jax_hpc_profiler-0.2.9.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
jax_hpc_profiler-0.2.9.dist-info/RECORD,,

jax_hpc_profiler-0.2.7.dist-info/RECORD
REMOVED

@@ -1,12 +0,0 @@
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
jax_hpc_profiler/plotting.py,sha256=cwHznCZ2pF2J7AtyUOB3pASnahKBLRWHAPGXmGDvWas,8360
jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
jax_hpc_profiler/utils.py,sha256=okWQUJHblUKkYnw7j7wJ75PSbhVItXKkTMKjj0BmgR0,14132
jax_hpc_profiler-0.2.7.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
jax_hpc_profiler-0.2.7.dist-info/METADATA,sha256=bQkpy5Kr8ybEM7GU7qR0FEnDV7xsLbrq98GRDfgDTQU,49250
jax_hpc_profiler-0.2.7.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
jax_hpc_profiler-0.2.7.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
jax_hpc_profiler-0.2.7.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
jax_hpc_profiler-0.2.7.dist-info/RECORD,,

The remaining files (LICENSE, entry_points.txt, top_level.txt) are unchanged between 0.2.7 and 0.2.9.