jax-hpc-profiler 0.2.12__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/__init__.py +9 -3
- jax_hpc_profiler/create_argparse.py +128 -120
- jax_hpc_profiler/main.py +41 -22
- jax_hpc_profiler/plotting.py +250 -68
- jax_hpc_profiler/timer.py +117 -126
- jax_hpc_profiler/utils.py +191 -132
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/METADATA +36 -4
- jax_hpc_profiler-0.3.0.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.12.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/top_level.txt +0 -0
jax_hpc_profiler/__init__.py
CHANGED
|
@@ -1,9 +1,15 @@
|
|
|
1
1
|
from .create_argparse import create_argparser
|
|
2
|
-
from .plotting import plot_strong_scaling, plot_weak_scaling
|
|
2
|
+
from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
|
|
3
3
|
from .timer import Timer
|
|
4
4
|
from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
|
|
5
5
|
|
|
6
6
|
__all__ = [
|
|
7
|
-
'create_argparser',
|
|
8
|
-
'
|
|
7
|
+
'create_argparser',
|
|
8
|
+
'plot_strong_scaling',
|
|
9
|
+
'plot_weak_scaling',
|
|
10
|
+
'plot_weak_fixed_scaling',
|
|
11
|
+
'Timer',
|
|
12
|
+
'clean_up_csv',
|
|
13
|
+
'concatenate_csvs',
|
|
14
|
+
'plot_with_pdims_strategy',
|
|
9
15
|
]
|
|
@@ -10,201 +10,209 @@ def create_argparser():
|
|
|
10
10
|
argparse.Namespace
|
|
11
11
|
Parsed and validated arguments.
|
|
12
12
|
"""
|
|
13
|
-
parser = argparse.ArgumentParser(
|
|
14
|
-
description="HPC Plotter for benchmarking data")
|
|
13
|
+
parser = argparse.ArgumentParser(description='HPC Plotter for benchmarking data')
|
|
15
14
|
|
|
16
15
|
# Group for concatenation to ensure mutually exclusive behavior
|
|
17
|
-
subparsers = parser.add_subparsers(dest=
|
|
16
|
+
subparsers = parser.add_subparsers(dest='command', required=True)
|
|
18
17
|
|
|
19
|
-
concat_parser = subparsers.add_parser(
|
|
20
|
-
|
|
21
|
-
concat_parser.add_argument(
|
|
22
|
-
type=str,
|
|
23
|
-
help="Input directory for concatenation")
|
|
24
|
-
concat_parser.add_argument("output",
|
|
25
|
-
type=str,
|
|
26
|
-
help="Output directory for concatenation")
|
|
18
|
+
concat_parser = subparsers.add_parser('concat', help='Concatenate CSV files')
|
|
19
|
+
concat_parser.add_argument('input', type=str, help='Input directory for concatenation')
|
|
20
|
+
concat_parser.add_argument('output', type=str, help='Output directory for concatenation')
|
|
27
21
|
|
|
28
22
|
# Arguments for plotting
|
|
29
|
-
plot_parser = subparsers.add_parser(
|
|
30
|
-
plot_parser.add_argument(
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
"--gpus",
|
|
38
|
-
nargs="*",
|
|
23
|
+
plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
|
|
24
|
+
plot_parser.add_argument(
|
|
25
|
+
'-f', '--csv_files', nargs='+', help='List of CSV files to plot', required=True
|
|
26
|
+
)
|
|
27
|
+
plot_parser.add_argument(
|
|
28
|
+
'-g',
|
|
29
|
+
'--gpus',
|
|
30
|
+
nargs='*',
|
|
39
31
|
type=int,
|
|
40
|
-
help=
|
|
32
|
+
help='List of number of GPUs to plot',
|
|
41
33
|
default=None,
|
|
42
34
|
)
|
|
43
35
|
plot_parser.add_argument(
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
nargs=
|
|
36
|
+
'-d',
|
|
37
|
+
'--data_size',
|
|
38
|
+
nargs='*',
|
|
47
39
|
type=int,
|
|
48
|
-
help=
|
|
40
|
+
help='List of data sizes to plot',
|
|
49
41
|
default=None,
|
|
50
42
|
)
|
|
51
43
|
|
|
52
44
|
# pdims related arguments
|
|
53
45
|
plot_parser.add_argument(
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
nargs=
|
|
57
|
-
help=
|
|
46
|
+
'-fd',
|
|
47
|
+
'--filter_pdims',
|
|
48
|
+
nargs='*',
|
|
49
|
+
help='List of pdims to filter, e.g., 1x4 2x2 4x8',
|
|
58
50
|
default=None,
|
|
59
51
|
)
|
|
60
52
|
plot_parser.add_argument(
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
choices=[
|
|
64
|
-
nargs=
|
|
65
|
-
default=[
|
|
66
|
-
help=
|
|
53
|
+
'-ps',
|
|
54
|
+
'--pdim_strategy',
|
|
55
|
+
choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
|
|
56
|
+
nargs='*',
|
|
57
|
+
default=['plot_fastest'],
|
|
58
|
+
help='Strategy for plotting pdims',
|
|
67
59
|
)
|
|
68
60
|
|
|
69
61
|
# Function and precision related arguments
|
|
70
62
|
plot_parser.add_argument(
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
choices=[
|
|
74
|
-
default=[
|
|
75
|
-
nargs=
|
|
76
|
-
help=
|
|
63
|
+
'-pr',
|
|
64
|
+
'--precision',
|
|
65
|
+
choices=['float32', 'float64'],
|
|
66
|
+
default=['float32', 'float64'],
|
|
67
|
+
nargs='*',
|
|
68
|
+
help='Precision to filter by (float32 or float64)',
|
|
77
69
|
)
|
|
78
70
|
plot_parser.add_argument(
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
nargs=
|
|
82
|
-
help=
|
|
71
|
+
'-fn',
|
|
72
|
+
'--function_name',
|
|
73
|
+
nargs='+',
|
|
74
|
+
help='Function names to filter',
|
|
83
75
|
default=None,
|
|
84
76
|
)
|
|
85
77
|
|
|
86
78
|
# Time or memory related arguments
|
|
87
79
|
plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
|
|
88
80
|
plotting_group.add_argument(
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
nargs=
|
|
81
|
+
'-pt',
|
|
82
|
+
'--plot_times',
|
|
83
|
+
nargs='*',
|
|
92
84
|
choices=[
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
85
|
+
'jit_time',
|
|
86
|
+
'min_time',
|
|
87
|
+
'max_time',
|
|
88
|
+
'mean_time',
|
|
89
|
+
'std_time',
|
|
90
|
+
'last_time',
|
|
99
91
|
],
|
|
100
|
-
help=
|
|
92
|
+
help='Time columns to plot',
|
|
101
93
|
)
|
|
102
94
|
plotting_group.add_argument(
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
nargs=
|
|
106
|
-
choices=[
|
|
107
|
-
|
|
108
|
-
],
|
|
109
|
-
help="Memory columns to plot",
|
|
95
|
+
'-pm',
|
|
96
|
+
'--plot_memory',
|
|
97
|
+
nargs='*',
|
|
98
|
+
choices=['generated_code', 'argument_size', 'output_size', 'temp_size'],
|
|
99
|
+
help='Memory columns to plot',
|
|
110
100
|
)
|
|
111
101
|
plot_parser.add_argument(
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
default=
|
|
115
|
-
help=
|
|
102
|
+
'-mu',
|
|
103
|
+
'--memory_units',
|
|
104
|
+
default='GB',
|
|
105
|
+
help='Memory units to plot (KB, MB, GB, TB)',
|
|
116
106
|
)
|
|
117
107
|
|
|
118
108
|
# Plot customization arguments
|
|
119
|
-
plot_parser.add_argument(
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
plot_parser.add_argument(
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
plot_parser.add_argument(
|
|
134
|
-
"-pd",
|
|
135
|
-
"--print_decompositions",
|
|
136
|
-
action="store_true",
|
|
137
|
-
help="Print decompositions on plot",
|
|
109
|
+
plot_parser.add_argument(
|
|
110
|
+
'-fs', '--figure_size', nargs=2, type=int, help='Figure size', default=(10, 6)
|
|
111
|
+
)
|
|
112
|
+
plot_parser.add_argument(
|
|
113
|
+
'-o', '--output', help='Output file (if none then only show plot)', default=None
|
|
114
|
+
)
|
|
115
|
+
plot_parser.add_argument(
|
|
116
|
+
'-db', '--dark_bg', action='store_true', help='Use dark background for plotting'
|
|
117
|
+
)
|
|
118
|
+
plot_parser.add_argument(
|
|
119
|
+
'-pd',
|
|
120
|
+
'--print_decompositions',
|
|
121
|
+
action='store_true',
|
|
122
|
+
help='Print decompositions on plot',
|
|
138
123
|
)
|
|
139
124
|
|
|
140
125
|
# Backend related arguments
|
|
141
126
|
plot_parser.add_argument(
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
nargs=
|
|
145
|
-
default=[
|
|
146
|
-
help=
|
|
127
|
+
'-b',
|
|
128
|
+
'--backends',
|
|
129
|
+
nargs='*',
|
|
130
|
+
default=['MPI', 'NCCL', 'MPI4JAX'],
|
|
131
|
+
help='List of backends to include',
|
|
147
132
|
)
|
|
148
133
|
|
|
149
134
|
# Scaling type argument
|
|
150
135
|
plot_parser.add_argument(
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
choices=[
|
|
136
|
+
'-sc',
|
|
137
|
+
'--scaling',
|
|
138
|
+
choices=['Weak', 'Strong', 'WeakFixed', 'w', 's', 'wf'],
|
|
154
139
|
required=True,
|
|
155
|
-
help=
|
|
140
|
+
help='Scaling type (Strong, Weak, or WeakFixed)',
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Weak-scaling specific options
|
|
144
|
+
plot_parser.add_argument(
|
|
145
|
+
'--weak_ideal_line',
|
|
146
|
+
action='store_true',
|
|
147
|
+
help='Overlay an ideal flat line for weak scaling (Weak mode only)',
|
|
148
|
+
)
|
|
149
|
+
plot_parser.add_argument(
|
|
150
|
+
'--weak_reverse_axes',
|
|
151
|
+
action='store_true',
|
|
152
|
+
help=(
|
|
153
|
+
'Weak mode only: put data size on the x-axis and annotate each point with GPUs instead '
|
|
154
|
+
'of data size. Requires --gpus and --data_size with equal lengths.'
|
|
155
|
+
),
|
|
156
156
|
)
|
|
157
157
|
|
|
158
158
|
# Label customization argument
|
|
159
159
|
plot_parser.add_argument(
|
|
160
|
-
|
|
161
|
-
|
|
160
|
+
'-l',
|
|
161
|
+
'--label_text',
|
|
162
162
|
type=str,
|
|
163
|
-
help=
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
default=
|
|
163
|
+
help=(
|
|
164
|
+
'Custom label for the plot. You can use placeholders: %%decomposition%% '
|
|
165
|
+
'(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), '
|
|
166
|
+
'%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)'
|
|
167
|
+
),
|
|
168
|
+
default='%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
|
|
169
169
|
)
|
|
170
170
|
|
|
171
171
|
plot_parser.add_argument(
|
|
172
|
-
|
|
173
|
-
|
|
172
|
+
'-xl',
|
|
173
|
+
'--xlabel',
|
|
174
174
|
type=str,
|
|
175
|
-
help=
|
|
175
|
+
help='X-axis label for the plot',
|
|
176
176
|
)
|
|
177
177
|
plot_parser.add_argument(
|
|
178
|
-
|
|
179
|
-
|
|
178
|
+
'-tl',
|
|
179
|
+
'--title',
|
|
180
180
|
type=str,
|
|
181
|
-
help=
|
|
181
|
+
help='Title for the plot',
|
|
182
182
|
)
|
|
183
183
|
|
|
184
|
-
subparsers.add_parser(
|
|
184
|
+
subparsers.add_parser('label_help', help='Label customization help')
|
|
185
185
|
|
|
186
186
|
args = parser.parse_args()
|
|
187
187
|
|
|
188
188
|
# if command was plot, then check if pdim_strategy is valid
|
|
189
|
-
if args.command ==
|
|
190
|
-
if
|
|
189
|
+
if args.command == 'plot':
|
|
190
|
+
if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
191
191
|
print(
|
|
192
|
-
"
|
|
192
|
+
"""
|
|
193
|
+
Warning: 'plot_all' strategy is combined with other strategies.
|
|
194
|
+
Using 'plot_all' only.
|
|
195
|
+
"""
|
|
193
196
|
)
|
|
194
|
-
args.pdim_strategy = [
|
|
197
|
+
args.pdim_strategy = ['plot_all']
|
|
195
198
|
|
|
196
|
-
if
|
|
197
|
-
args.pdim_strategy) > 1:
|
|
199
|
+
if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
198
200
|
print(
|
|
199
|
-
"
|
|
201
|
+
"""
|
|
202
|
+
Warning: 'plot_fastest' strategy is combined with other strategies.
|
|
203
|
+
Using 'plot_fastest' only.
|
|
204
|
+
"""
|
|
200
205
|
)
|
|
201
|
-
args.pdim_strategy = [
|
|
206
|
+
args.pdim_strategy = ['plot_fastest']
|
|
202
207
|
if args.plot_times is not None:
|
|
203
208
|
args.plot_columns = args.plot_times
|
|
204
209
|
elif args.plot_memory is not None:
|
|
205
210
|
args.plot_columns = args.plot_memory
|
|
206
211
|
else:
|
|
207
|
-
raise ValueError(
|
|
208
|
-
|
|
212
|
+
raise ValueError('Either plot_times or plot_memory should be provided')
|
|
213
|
+
|
|
214
|
+
# Note: for Weak scaling, plot_weak_scaling enforces that both gpus and
|
|
215
|
+
# data_size are provided and have matching lengths. For Strong and
|
|
216
|
+
# WeakFixed, gpus/data_size remain optional as before.
|
|
209
217
|
|
|
210
218
|
return args
|
jax_hpc_profiler/main.py
CHANGED
|
@@ -1,29 +1,26 @@
|
|
|
1
|
-
import sys
|
|
2
|
-
from typing import List, Optional
|
|
3
|
-
|
|
4
1
|
from .create_argparse import create_argparser
|
|
5
|
-
from .plotting import plot_strong_scaling, plot_weak_scaling
|
|
6
|
-
from .utils import
|
|
2
|
+
from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
|
|
3
|
+
from .utils import concatenate_csvs
|
|
7
4
|
|
|
8
5
|
|
|
9
6
|
def main():
|
|
10
7
|
args = create_argparser()
|
|
11
8
|
|
|
12
|
-
if args.command ==
|
|
9
|
+
if args.command == 'concat':
|
|
13
10
|
input_dir, output_dir = args.input, args.output
|
|
14
11
|
concatenate_csvs(input_dir, output_dir)
|
|
15
|
-
elif args.command ==
|
|
16
|
-
print(
|
|
17
|
-
print(
|
|
18
|
-
print(
|
|
19
|
-
print(
|
|
20
|
-
print(
|
|
21
|
-
print(
|
|
22
|
-
print(
|
|
23
|
-
print(
|
|
24
|
-
elif args.command ==
|
|
25
|
-
|
|
26
|
-
if
|
|
12
|
+
elif args.command == 'label_help':
|
|
13
|
+
print('Customize the label text for the plot. using these commands.')
|
|
14
|
+
print(' -- %m% or %methodname%: method name')
|
|
15
|
+
print(' -- %f% or %function%: function name')
|
|
16
|
+
print(' -- %pn% or %plot_name%: plot name')
|
|
17
|
+
print(' -- %pr% or %precision%: precision')
|
|
18
|
+
print(' -- %b% or %backend%: backend')
|
|
19
|
+
print(' -- %p% or %pdims%: pdims')
|
|
20
|
+
print(' -- %n% or %node%: node')
|
|
21
|
+
elif args.command == 'plot':
|
|
22
|
+
scaling = args.scaling.lower()
|
|
23
|
+
if scaling in ('weak', 'w'):
|
|
27
24
|
plot_weak_scaling(
|
|
28
25
|
args.csv_files,
|
|
29
26
|
args.gpus,
|
|
@@ -37,13 +34,15 @@ def main():
|
|
|
37
34
|
args.plot_columns,
|
|
38
35
|
args.memory_units,
|
|
39
36
|
args.label_text,
|
|
37
|
+
args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
|
|
40
38
|
args.title,
|
|
41
|
-
args.label_text,
|
|
42
39
|
args.figure_size,
|
|
43
40
|
args.dark_bg,
|
|
44
41
|
args.output,
|
|
42
|
+
args.weak_ideal_line,
|
|
43
|
+
args.weak_reverse_axes,
|
|
45
44
|
)
|
|
46
|
-
elif
|
|
45
|
+
elif scaling in ('strong', 's'):
|
|
47
46
|
plot_strong_scaling(
|
|
48
47
|
args.csv_files,
|
|
49
48
|
args.gpus,
|
|
@@ -57,13 +56,33 @@ def main():
|
|
|
57
56
|
args.plot_columns,
|
|
58
57
|
args.memory_units,
|
|
59
58
|
args.label_text,
|
|
60
|
-
args.
|
|
59
|
+
args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
|
|
60
|
+
args.title if getattr(args, 'title', None) is not None else 'Data sizes',
|
|
61
|
+
args.figure_size,
|
|
62
|
+
args.dark_bg,
|
|
63
|
+
args.output,
|
|
64
|
+
)
|
|
65
|
+
elif scaling in ('weakfixed', 'wf'):
|
|
66
|
+
plot_weak_fixed_scaling(
|
|
67
|
+
args.csv_files,
|
|
68
|
+
args.gpus,
|
|
69
|
+
args.data_size,
|
|
70
|
+
args.function_name,
|
|
71
|
+
args.precision,
|
|
72
|
+
args.filter_pdims,
|
|
73
|
+
args.pdim_strategy,
|
|
74
|
+
args.print_decompositions,
|
|
75
|
+
args.backends,
|
|
76
|
+
args.plot_columns,
|
|
77
|
+
args.memory_units,
|
|
61
78
|
args.label_text,
|
|
79
|
+
args.xlabel if getattr(args, 'xlabel', None) is not None else 'Data sizes',
|
|
80
|
+
args.title if getattr(args, 'title', None) is not None else 'Number of GPUs',
|
|
62
81
|
args.figure_size,
|
|
63
82
|
args.dark_bg,
|
|
64
83
|
args.output,
|
|
65
84
|
)
|
|
66
85
|
|
|
67
86
|
|
|
68
|
-
if __name__ ==
|
|
87
|
+
if __name__ == '__main__':
|
|
69
88
|
main()
|