jax-hpc-profiler 0.2.12__tar.gz → 0.2.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/pyproject.toml +28 -1
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/__init__.py +7 -2
- jax_hpc_profiler-0.2.13/src/jax_hpc_profiler/create_argparse.py +199 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/main.py +15 -19
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/plotting.py +58 -66
- jax_hpc_profiler-0.2.13/src/jax_hpc_profiler/timer.py +276 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/utils.py +191 -132
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler/create_argparse.py +0 -210
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler/timer.py +0 -289
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/README.md +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "jax_hpc_profiler"
|
|
7
|
-
version = "0.2.
|
|
7
|
+
version = "0.2.13"
|
|
8
8
|
description = "HPC Plotter and profiler for benchmarking data made for JAX"
|
|
9
9
|
authors = [
|
|
10
10
|
{ name="Wassim Kabalan" }
|
|
@@ -47,3 +47,30 @@ urls = { "Homepage" = "https://github.com/ASKabalan/jax-hpc-profiler" }
|
|
|
47
47
|
|
|
48
48
|
[project.scripts]
|
|
49
49
|
jhp = "jax_hpc_profiler.main:main"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
[tool.ruff]
|
|
54
|
+
line-length = 100
|
|
55
|
+
fix = true # autofix issues
|
|
56
|
+
force-exclude = true # useful with ruff-pre-commit plugin
|
|
57
|
+
src = ["src"]
|
|
58
|
+
|
|
59
|
+
[tool.ruff.lint]
|
|
60
|
+
select = [
|
|
61
|
+
'ARG001', # flake8-unused-function-arguments
|
|
62
|
+
'E', # pycodestyle-errors
|
|
63
|
+
'F', # pyflakes
|
|
64
|
+
'I', # isort
|
|
65
|
+
'UP', # pyupgrade
|
|
66
|
+
'T10', # flake8-debugger
|
|
67
|
+
]
|
|
68
|
+
ignore = [
|
|
69
|
+
'E203',
|
|
70
|
+
'E731',
|
|
71
|
+
'E741',
|
|
72
|
+
'F722', # conflicts with jaxtyping Array annotations
|
|
73
|
+
]
|
|
74
|
+
|
|
75
|
+
[tool.ruff.format]
|
|
76
|
+
quote-style = 'single'
|
|
@@ -4,6 +4,11 @@ from .timer import Timer
|
|
|
4
4
|
from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
|
|
5
5
|
|
|
6
6
|
__all__ = [
|
|
7
|
-
'create_argparser',
|
|
8
|
-
'
|
|
7
|
+
'create_argparser',
|
|
8
|
+
'plot_strong_scaling',
|
|
9
|
+
'plot_weak_scaling',
|
|
10
|
+
'Timer',
|
|
11
|
+
'clean_up_csv',
|
|
12
|
+
'concatenate_csvs',
|
|
13
|
+
'plot_with_pdims_strategy',
|
|
9
14
|
]
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def create_argparser():
|
|
5
|
+
"""
|
|
6
|
+
Create argument parser for the HPC Plotter package.
|
|
7
|
+
|
|
8
|
+
Returns
|
|
9
|
+
-------
|
|
10
|
+
argparse.Namespace
|
|
11
|
+
Parsed and validated arguments.
|
|
12
|
+
"""
|
|
13
|
+
parser = argparse.ArgumentParser(description='HPC Plotter for benchmarking data')
|
|
14
|
+
|
|
15
|
+
# Group for concatenation to ensure mutually exclusive behavior
|
|
16
|
+
subparsers = parser.add_subparsers(dest='command', required=True)
|
|
17
|
+
|
|
18
|
+
concat_parser = subparsers.add_parser('concat', help='Concatenate CSV files')
|
|
19
|
+
concat_parser.add_argument('input', type=str, help='Input directory for concatenation')
|
|
20
|
+
concat_parser.add_argument('output', type=str, help='Output directory for concatenation')
|
|
21
|
+
|
|
22
|
+
# Arguments for plotting
|
|
23
|
+
plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
|
|
24
|
+
plot_parser.add_argument(
|
|
25
|
+
'-f', '--csv_files', nargs='+', help='List of CSV files to plot', required=True
|
|
26
|
+
)
|
|
27
|
+
plot_parser.add_argument(
|
|
28
|
+
'-g',
|
|
29
|
+
'--gpus',
|
|
30
|
+
nargs='*',
|
|
31
|
+
type=int,
|
|
32
|
+
help='List of number of GPUs to plot',
|
|
33
|
+
default=None,
|
|
34
|
+
)
|
|
35
|
+
plot_parser.add_argument(
|
|
36
|
+
'-d',
|
|
37
|
+
'--data_size',
|
|
38
|
+
nargs='*',
|
|
39
|
+
type=int,
|
|
40
|
+
help='List of data sizes to plot',
|
|
41
|
+
default=None,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# pdims related arguments
|
|
45
|
+
plot_parser.add_argument(
|
|
46
|
+
'-fd',
|
|
47
|
+
'--filter_pdims',
|
|
48
|
+
nargs='*',
|
|
49
|
+
help='List of pdims to filter, e.g., 1x4 2x2 4x8',
|
|
50
|
+
default=None,
|
|
51
|
+
)
|
|
52
|
+
plot_parser.add_argument(
|
|
53
|
+
'-ps',
|
|
54
|
+
'--pdim_strategy',
|
|
55
|
+
choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
|
|
56
|
+
nargs='*',
|
|
57
|
+
default=['plot_fastest'],
|
|
58
|
+
help='Strategy for plotting pdims',
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
# Function and precision related arguments
|
|
62
|
+
plot_parser.add_argument(
|
|
63
|
+
'-pr',
|
|
64
|
+
'--precision',
|
|
65
|
+
choices=['float32', 'float64'],
|
|
66
|
+
default=['float32', 'float64'],
|
|
67
|
+
nargs='*',
|
|
68
|
+
help='Precision to filter by (float32 or float64)',
|
|
69
|
+
)
|
|
70
|
+
plot_parser.add_argument(
|
|
71
|
+
'-fn',
|
|
72
|
+
'--function_name',
|
|
73
|
+
nargs='+',
|
|
74
|
+
help='Function names to filter',
|
|
75
|
+
default=None,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# Time or memory related arguments
|
|
79
|
+
plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
|
|
80
|
+
plotting_group.add_argument(
|
|
81
|
+
'-pt',
|
|
82
|
+
'--plot_times',
|
|
83
|
+
nargs='*',
|
|
84
|
+
choices=[
|
|
85
|
+
'jit_time',
|
|
86
|
+
'min_time',
|
|
87
|
+
'max_time',
|
|
88
|
+
'mean_time',
|
|
89
|
+
'std_time',
|
|
90
|
+
'last_time',
|
|
91
|
+
],
|
|
92
|
+
help='Time columns to plot',
|
|
93
|
+
)
|
|
94
|
+
plotting_group.add_argument(
|
|
95
|
+
'-pm',
|
|
96
|
+
'--plot_memory',
|
|
97
|
+
nargs='*',
|
|
98
|
+
choices=['generated_code', 'argument_size', 'output_size', 'temp_size'],
|
|
99
|
+
help='Memory columns to plot',
|
|
100
|
+
)
|
|
101
|
+
plot_parser.add_argument(
|
|
102
|
+
'-mu',
|
|
103
|
+
'--memory_units',
|
|
104
|
+
default='GB',
|
|
105
|
+
help='Memory units to plot (KB, MB, GB, TB)',
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
# Plot customization arguments
|
|
109
|
+
plot_parser.add_argument(
|
|
110
|
+
'-fs', '--figure_size', nargs=2, type=int, help='Figure size', default=(10, 6)
|
|
111
|
+
)
|
|
112
|
+
plot_parser.add_argument(
|
|
113
|
+
'-o', '--output', help='Output file (if none then only show plot)', default=None
|
|
114
|
+
)
|
|
115
|
+
plot_parser.add_argument(
|
|
116
|
+
'-db', '--dark_bg', action='store_true', help='Use dark background for plotting'
|
|
117
|
+
)
|
|
118
|
+
plot_parser.add_argument(
|
|
119
|
+
'-pd',
|
|
120
|
+
'--print_decompositions',
|
|
121
|
+
action='store_true',
|
|
122
|
+
help='Print decompositions on plot',
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
# Backend related arguments
|
|
126
|
+
plot_parser.add_argument(
|
|
127
|
+
'-b',
|
|
128
|
+
'--backends',
|
|
129
|
+
nargs='*',
|
|
130
|
+
default=['MPI', 'NCCL', 'MPI4JAX'],
|
|
131
|
+
help='List of backends to include',
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
# Scaling type argument
|
|
135
|
+
plot_parser.add_argument(
|
|
136
|
+
'-sc',
|
|
137
|
+
'--scaling',
|
|
138
|
+
choices=['Weak', 'Strong', 'w', 's'],
|
|
139
|
+
required=True,
|
|
140
|
+
help='Scaling type (Weak or Strong)',
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Label customization argument
|
|
144
|
+
plot_parser.add_argument(
|
|
145
|
+
'-l',
|
|
146
|
+
'--label_text',
|
|
147
|
+
type=str,
|
|
148
|
+
help=(
|
|
149
|
+
'Custom label for the plot. You can use placeholders: %%decomposition%% '
|
|
150
|
+
'(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), '
|
|
151
|
+
'%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)'
|
|
152
|
+
),
|
|
153
|
+
default='%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
plot_parser.add_argument(
|
|
157
|
+
'-xl',
|
|
158
|
+
'--xlabel',
|
|
159
|
+
type=str,
|
|
160
|
+
help='X-axis label for the plot',
|
|
161
|
+
)
|
|
162
|
+
plot_parser.add_argument(
|
|
163
|
+
'-tl',
|
|
164
|
+
'--title',
|
|
165
|
+
type=str,
|
|
166
|
+
help='Title for the plot',
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
subparsers.add_parser('label_help', help='Label customization help')
|
|
170
|
+
|
|
171
|
+
args = parser.parse_args()
|
|
172
|
+
|
|
173
|
+
# if command was plot, then check if pdim_strategy is validat
|
|
174
|
+
if args.command == 'plot':
|
|
175
|
+
if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
176
|
+
print(
|
|
177
|
+
"""
|
|
178
|
+
Warning: 'plot_all' strategy is combined with other strategies.
|
|
179
|
+
Using 'plot_all' only.
|
|
180
|
+
"""
|
|
181
|
+
)
|
|
182
|
+
args.pdim_strategy = ['plot_all']
|
|
183
|
+
|
|
184
|
+
if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
185
|
+
print(
|
|
186
|
+
"""
|
|
187
|
+
Warning: 'plot_fastest' strategy is combined with other strategies.
|
|
188
|
+
Using 'plot_fastest' only.
|
|
189
|
+
"""
|
|
190
|
+
)
|
|
191
|
+
args.pdim_strategy = ['plot_fastest']
|
|
192
|
+
if args.plot_times is not None:
|
|
193
|
+
args.plot_columns = args.plot_times
|
|
194
|
+
elif args.plot_memory is not None:
|
|
195
|
+
args.plot_columns = args.plot_memory
|
|
196
|
+
else:
|
|
197
|
+
raise ValueError('Either plot_times or plot_memory should be provided')
|
|
198
|
+
|
|
199
|
+
return args
|
|
@@ -1,29 +1,25 @@
|
|
|
1
|
-
import sys
|
|
2
|
-
from typing import List, Optional
|
|
3
|
-
|
|
4
1
|
from .create_argparse import create_argparser
|
|
5
2
|
from .plotting import plot_strong_scaling, plot_weak_scaling
|
|
6
|
-
from .utils import
|
|
3
|
+
from .utils import concatenate_csvs
|
|
7
4
|
|
|
8
5
|
|
|
9
6
|
def main():
|
|
10
7
|
args = create_argparser()
|
|
11
8
|
|
|
12
|
-
if args.command ==
|
|
9
|
+
if args.command == 'concat':
|
|
13
10
|
input_dir, output_dir = args.input, args.output
|
|
14
11
|
concatenate_csvs(input_dir, output_dir)
|
|
15
|
-
elif args.command ==
|
|
16
|
-
print(
|
|
17
|
-
print(
|
|
18
|
-
print(
|
|
19
|
-
print(
|
|
20
|
-
print(
|
|
21
|
-
print(
|
|
22
|
-
print(
|
|
23
|
-
print(
|
|
24
|
-
elif args.command ==
|
|
25
|
-
|
|
26
|
-
if args.scaling.lower() == "weak" or args.scaling.lower() == "w":
|
|
12
|
+
elif args.command == 'label_help':
|
|
13
|
+
print('Customize the label text for the plot. using these commands.')
|
|
14
|
+
print(' -- %m% or %methodname%: method name')
|
|
15
|
+
print(' -- %f% or %function%: function name')
|
|
16
|
+
print(' -- %pn% or %plot_name%: plot name')
|
|
17
|
+
print(' -- %pr% or %precision%: precision')
|
|
18
|
+
print(' -- %b% or %backend%: backend')
|
|
19
|
+
print(' -- %p% or %pdims%: pdims')
|
|
20
|
+
print(' -- %n% or %node%: node')
|
|
21
|
+
elif args.command == 'plot':
|
|
22
|
+
if args.scaling.lower() == 'weak' or args.scaling.lower() == 'w':
|
|
27
23
|
plot_weak_scaling(
|
|
28
24
|
args.csv_files,
|
|
29
25
|
args.gpus,
|
|
@@ -43,7 +39,7 @@ def main():
|
|
|
43
39
|
args.dark_bg,
|
|
44
40
|
args.output,
|
|
45
41
|
)
|
|
46
|
-
elif args.scaling.lower() ==
|
|
42
|
+
elif args.scaling.lower() == 'strong' or args.scaling.lower() == 's':
|
|
47
43
|
plot_strong_scaling(
|
|
48
44
|
args.csv_files,
|
|
49
45
|
args.gpus,
|
|
@@ -65,5 +61,5 @@ def main():
|
|
|
65
61
|
)
|
|
66
62
|
|
|
67
63
|
|
|
68
|
-
if __name__ ==
|
|
64
|
+
if __name__ == '__main__':
|
|
69
65
|
main()
|
|
@@ -4,23 +4,22 @@ from typing import Dict, List, Optional
|
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
-
import seaborn as sns
|
|
8
7
|
from matplotlib.axes import Axes
|
|
9
8
|
from matplotlib.patches import FancyBboxPatch
|
|
10
9
|
|
|
11
|
-
from .utils import clean_up_csv,
|
|
10
|
+
from .utils import clean_up_csv, plot_with_pdims_strategy
|
|
12
11
|
|
|
13
|
-
np.seterr(divide=
|
|
12
|
+
np.seterr(divide='ignore')
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
def configure_axes(
|
|
17
16
|
ax: Axes,
|
|
18
17
|
x_values: List[int],
|
|
19
18
|
y_values: List[float],
|
|
20
|
-
title: str,
|
|
19
|
+
title: Optional[str],
|
|
21
20
|
xlabel: str,
|
|
22
21
|
plotting_memory: bool = False,
|
|
23
|
-
memory_units: str =
|
|
22
|
+
memory_units: str = 'bytes',
|
|
24
23
|
):
|
|
25
24
|
"""
|
|
26
25
|
Configure the axes for the plot.
|
|
@@ -36,33 +35,32 @@ def configure_axes(
|
|
|
36
35
|
xlabel : str
|
|
37
36
|
The label for the x-axis.
|
|
38
37
|
"""
|
|
39
|
-
ylabel =
|
|
40
|
-
|
|
41
|
-
f2
|
|
42
|
-
|
|
38
|
+
ylabel = 'Time (milliseconds)' if not plotting_memory else f'Memory ({memory_units})'
|
|
39
|
+
|
|
40
|
+
def f2(x):
|
|
41
|
+
return np.log2(x)
|
|
42
|
+
|
|
43
|
+
def g2(x):
|
|
44
|
+
return 2**x
|
|
45
|
+
|
|
43
46
|
ax.set_xlim([min(x_values), max(x_values)])
|
|
44
47
|
y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
|
|
45
48
|
ax.set_title(title)
|
|
46
49
|
ax.set_ylim([y_min, y_max])
|
|
47
|
-
ax.set_xscale(
|
|
50
|
+
ax.set_xscale('function', functions=(f2, g2))
|
|
48
51
|
if not plotting_memory:
|
|
49
|
-
ax.set_yscale(
|
|
52
|
+
ax.set_yscale('symlog')
|
|
50
53
|
time_ticks = [
|
|
51
|
-
10**t for t in range(int(np.floor(np.log10(y_min))), 1 +
|
|
52
|
-
int(np.ceil(np.log10(y_max))))
|
|
54
|
+
10**t for t in range(int(np.floor(np.log10(y_min))), 1 + int(np.ceil(np.log10(y_max))))
|
|
53
55
|
]
|
|
54
56
|
ax.set_yticks(time_ticks)
|
|
55
57
|
ax.set_xticks(x_values)
|
|
56
58
|
ax.set_xlabel(xlabel)
|
|
57
59
|
ax.set_ylabel(ylabel)
|
|
58
60
|
for x_value in x_values:
|
|
59
|
-
ax.axvline(x=x_value, color=
|
|
61
|
+
ax.axvline(x=x_value, color='gray', linestyle='--', alpha=0.5)
|
|
60
62
|
ax.legend(
|
|
61
|
-
loc=
|
|
62
|
-
bbox_to_anchor=(0.5, 0.05),
|
|
63
|
-
ncol=4,
|
|
64
|
-
fontsize="x-large",
|
|
65
|
-
prop={"size": 14},
|
|
63
|
+
loc='best',
|
|
66
64
|
)
|
|
67
65
|
|
|
68
66
|
|
|
@@ -80,10 +78,10 @@ def plot_scaling(
|
|
|
80
78
|
backends: Optional[List[str]] = None,
|
|
81
79
|
precisions: Optional[List[str]] = None,
|
|
82
80
|
functions: Optional[List[str]] = None,
|
|
83
|
-
plot_columns: List[str] = [
|
|
84
|
-
memory_units: str =
|
|
85
|
-
label_text: str =
|
|
86
|
-
pdims_strategy: List[str] = [
|
|
81
|
+
plot_columns: List[str] = ['mean_time'],
|
|
82
|
+
memory_units: str = 'bytes',
|
|
83
|
+
label_text: str = 'plot',
|
|
84
|
+
pdims_strategy: List[str] = ['plot_fastest'],
|
|
87
85
|
):
|
|
88
86
|
"""
|
|
89
87
|
General scaling plot function based on the number of GPUs or data size.
|
|
@@ -115,7 +113,7 @@ def plot_scaling(
|
|
|
115
113
|
"""
|
|
116
114
|
|
|
117
115
|
if dark_bg:
|
|
118
|
-
plt.style.use(
|
|
116
|
+
plt.style.use('dark_background')
|
|
119
117
|
|
|
120
118
|
num_subplots = len(fixed_sizes)
|
|
121
119
|
num_rows = int(np.ceil(np.sqrt(num_subplots)))
|
|
@@ -133,28 +131,26 @@ def plot_scaling(
|
|
|
133
131
|
x_values = []
|
|
134
132
|
y_values = []
|
|
135
133
|
for method, df in dataframes.items():
|
|
136
|
-
|
|
137
134
|
filtered_method_df = df[df[fixed_column] == int(fixed_size)]
|
|
138
135
|
if filtered_method_df.empty:
|
|
139
136
|
continue
|
|
140
|
-
filtered_method_df = filtered_method_df.sort_values(
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
precisions = (
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
137
|
+
filtered_method_df = filtered_method_df.sort_values(by=[size_column])
|
|
138
|
+
functions = (
|
|
139
|
+
pd.unique(filtered_method_df['function']) if functions is None else functions
|
|
140
|
+
)
|
|
141
|
+
precisions = (
|
|
142
|
+
pd.unique(filtered_method_df['precision']) if precisions is None else precisions
|
|
143
|
+
)
|
|
144
|
+
backends = pd.unique(filtered_method_df['backend']) if backends is None else backends
|
|
148
145
|
|
|
149
|
-
combinations = product(backends, precisions, functions,
|
|
150
|
-
plot_columns)
|
|
146
|
+
combinations = product(backends, precisions, functions, plot_columns)
|
|
151
147
|
|
|
152
148
|
for backend, precision, function, plot_column in combinations:
|
|
153
|
-
|
|
154
149
|
filtered_params_df = filtered_method_df[
|
|
155
|
-
(filtered_method_df[
|
|
156
|
-
& (filtered_method_df[
|
|
157
|
-
& (filtered_method_df[
|
|
150
|
+
(filtered_method_df['backend'] == backend)
|
|
151
|
+
& (filtered_method_df['precision'] == precision)
|
|
152
|
+
& (filtered_method_df['function'] == function)
|
|
153
|
+
]
|
|
158
154
|
if filtered_params_df.empty:
|
|
159
155
|
continue
|
|
160
156
|
x_vals, y_vals = plot_with_pdims_strategy(
|
|
@@ -172,12 +168,13 @@ def plot_scaling(
|
|
|
172
168
|
y_values.extend(y_vals)
|
|
173
169
|
|
|
174
170
|
if len(x_values) != 0:
|
|
175
|
-
plotting_memory =
|
|
171
|
+
plotting_memory = 'time' not in plot_columns[0].lower()
|
|
172
|
+
figure_title = f'{title} {fixed_size}' if title is not None else None
|
|
176
173
|
configure_axes(
|
|
177
174
|
ax,
|
|
178
175
|
x_values,
|
|
179
176
|
y_values,
|
|
180
|
-
|
|
177
|
+
figure_title,
|
|
181
178
|
xlabel,
|
|
182
179
|
plotting_memory,
|
|
183
180
|
memory_units,
|
|
@@ -187,17 +184,12 @@ def plot_scaling(
|
|
|
187
184
|
fig.delaxes(axs[i])
|
|
188
185
|
|
|
189
186
|
fig.tight_layout()
|
|
190
|
-
rect = FancyBboxPatch((0.1, 0.1),
|
|
191
|
-
0.8,
|
|
192
|
-
0.8,
|
|
193
|
-
boxstyle="round,pad=0.02",
|
|
194
|
-
ec="black",
|
|
195
|
-
fc="none")
|
|
187
|
+
rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
|
|
196
188
|
fig.patches.append(rect)
|
|
197
189
|
if output is None:
|
|
198
190
|
plt.show()
|
|
199
191
|
else:
|
|
200
|
-
plt.savefig(output
|
|
192
|
+
plt.savefig(output)
|
|
201
193
|
|
|
202
194
|
|
|
203
195
|
def plot_strong_scaling(
|
|
@@ -207,14 +199,14 @@ def plot_strong_scaling(
|
|
|
207
199
|
functions: Optional[List[str]] = None,
|
|
208
200
|
precisions: Optional[List[str]] = None,
|
|
209
201
|
pdims: Optional[List[str]] = None,
|
|
210
|
-
pdims_strategy: List[str] = [
|
|
202
|
+
pdims_strategy: List[str] = ['plot_fastest'],
|
|
211
203
|
print_decompositions: bool = False,
|
|
212
204
|
backends: Optional[List[str]] = None,
|
|
213
|
-
plot_columns: List[str] = [
|
|
214
|
-
memory_units: str =
|
|
215
|
-
label_text: str =
|
|
216
|
-
xlabel: str =
|
|
217
|
-
title: str =
|
|
205
|
+
plot_columns: List[str] = ['mean_time'],
|
|
206
|
+
memory_units: str = 'bytes',
|
|
207
|
+
label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
|
|
208
|
+
xlabel: str = 'Number of GPUs',
|
|
209
|
+
title: str = 'Data sizes',
|
|
218
210
|
figure_size: tuple = (6, 4),
|
|
219
211
|
dark_bg: bool = False,
|
|
220
212
|
output: Optional[str] = None,
|
|
@@ -235,14 +227,14 @@ def plot_strong_scaling(
|
|
|
235
227
|
memory_units,
|
|
236
228
|
)
|
|
237
229
|
if len(dataframes) == 0:
|
|
238
|
-
print(
|
|
230
|
+
print('No dataframes found for the given arguments. Exiting...')
|
|
239
231
|
return
|
|
240
232
|
|
|
241
233
|
plot_scaling(
|
|
242
234
|
dataframes,
|
|
243
235
|
available_data_sizes,
|
|
244
|
-
|
|
245
|
-
|
|
236
|
+
'gpus',
|
|
237
|
+
'x',
|
|
246
238
|
xlabel,
|
|
247
239
|
title,
|
|
248
240
|
figure_size,
|
|
@@ -266,14 +258,14 @@ def plot_weak_scaling(
|
|
|
266
258
|
functions: Optional[List[str]] = None,
|
|
267
259
|
precisions: Optional[List[str]] = None,
|
|
268
260
|
pdims: Optional[List[str]] = None,
|
|
269
|
-
pdims_strategy: List[str] = [
|
|
261
|
+
pdims_strategy: List[str] = ['plot_fastest'],
|
|
270
262
|
print_decompositions: bool = False,
|
|
271
263
|
backends: Optional[List[str]] = None,
|
|
272
|
-
plot_columns: List[str] = [
|
|
273
|
-
memory_units: str =
|
|
274
|
-
label_text: str =
|
|
275
|
-
xlabel: str =
|
|
276
|
-
title: str =
|
|
264
|
+
plot_columns: List[str] = ['mean_time'],
|
|
265
|
+
memory_units: str = 'bytes',
|
|
266
|
+
label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
|
|
267
|
+
xlabel: str = 'Data sizes',
|
|
268
|
+
title: str = 'Number of GPUs',
|
|
277
269
|
figure_size: tuple = (6, 4),
|
|
278
270
|
dark_bg: bool = False,
|
|
279
271
|
output: Optional[str] = None,
|
|
@@ -293,14 +285,14 @@ def plot_weak_scaling(
|
|
|
293
285
|
memory_units,
|
|
294
286
|
)
|
|
295
287
|
if len(dataframes) == 0:
|
|
296
|
-
print(
|
|
288
|
+
print('No dataframes found for the given arguments. Exiting...')
|
|
297
289
|
return
|
|
298
290
|
|
|
299
291
|
plot_scaling(
|
|
300
292
|
dataframes,
|
|
301
293
|
available_gpu_counts,
|
|
302
|
-
|
|
303
|
-
|
|
294
|
+
'x',
|
|
295
|
+
'gpus',
|
|
304
296
|
xlabel,
|
|
305
297
|
title,
|
|
306
298
|
figure_size,
|