jax-hpc-profiler 0.2.12__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/PKG-INFO +36 -4
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/README.md +27 -1
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/pyproject.toml +49 -4
- jax_hpc_profiler-0.3.0/src/jax_hpc_profiler/__init__.py +15 -0
- jax_hpc_profiler-0.3.0/src/jax_hpc_profiler/create_argparse.py +218 -0
- jax_hpc_profiler-0.3.0/src/jax_hpc_profiler/main.py +88 -0
- jax_hpc_profiler-0.3.0/src/jax_hpc_profiler/plotting.py +499 -0
- jax_hpc_profiler-0.3.0/src/jax_hpc_profiler/timer.py +280 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/utils.py +191 -132
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/PKG-INFO +36 -4
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/SOURCES.txt +3 -1
- jax_hpc_profiler-0.3.0/src/jax_hpc_profiler.egg-info/requires.txt +12 -0
- jax_hpc_profiler-0.3.0/tests/test_plotting.py +112 -0
- jax_hpc_profiler-0.3.0/tests/test_timer.py +103 -0
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler/__init__.py +0 -9
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler/create_argparse.py +0 -210
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler/main.py +0 -69
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler/plotting.py +0 -317
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler/timer.py +0 -289
- jax_hpc_profiler-0.2.12/src/jax_hpc_profiler.egg-info/requires.txt +0 -5
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.12 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax_hpc_profiler
|
|
3
|
-
Version: 0.
|
|
4
|
-
Summary:
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization.
|
|
5
5
|
Author: Wassim Kabalan
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
7
7
|
Version 3, 29 June 2007
|
|
@@ -679,7 +679,7 @@ License: GNU GENERAL PUBLIC LICENSE
|
|
|
679
679
|
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
|
680
680
|
|
|
681
681
|
Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
|
|
682
|
-
Keywords: jax,hpc,
|
|
682
|
+
Keywords: jax,hpc,profiling,benchmarking,visualization,scaling,performance-analysis,gpu,distributed-computing
|
|
683
683
|
Classifier: Development Status :: 4 - Beta
|
|
684
684
|
Classifier: Intended Audience :: Developers
|
|
685
685
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
@@ -698,10 +698,22 @@ Requires-Dist: pandas
|
|
|
698
698
|
Requires-Dist: matplotlib
|
|
699
699
|
Requires-Dist: seaborn
|
|
700
700
|
Requires-Dist: tabulate
|
|
701
|
+
Requires-Dist: adjustText
|
|
702
|
+
Requires-Dist: jax>=0.4.0
|
|
703
|
+
Requires-Dist: jaxtyping
|
|
704
|
+
Provides-Extra: test
|
|
705
|
+
Requires-Dist: pytest; extra == "test"
|
|
706
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
701
707
|
Dynamic: license-file
|
|
702
708
|
|
|
703
709
|
# JAX HPC Profiler
|
|
704
710
|
|
|
711
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
|
|
712
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
|
|
713
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
|
|
714
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
|
|
715
|
+
[](https://www.gnu.org/licenses/gpl-3.0)
|
|
716
|
+
|
|
705
717
|
JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionality to generate, concatenate, and plot CSV data from multiple runs.
|
|
706
718
|
|
|
707
719
|
## Table of Contents
|
|
@@ -883,9 +895,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
|
|
|
883
895
|
- `-db, --dark_bg`: Use dark background for plotting.
|
|
884
896
|
- `-pd, --print_decompositions`: Print decompositions on plot (experimental).
|
|
885
897
|
- `-b, --backends`: List of backends to include. This argument can be multiple ones.
|
|
886
|
-
- `-sc, --scaling`: Scaling type (`Weak`, `
|
|
898
|
+
- `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
|
|
899
|
+
- `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
|
|
900
|
+
- `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
|
|
901
|
+
- `WeakFixed`: size scaling at a fixed GPU count (what `Weak` did in earlier versions); plots runtime (or memory) versus data size, grouped by number of GPUs.
|
|
902
|
+
- `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal flat reference line at the runtime measured with the smallest number of GPUs on the first plotted weak-scaling curve.
|
|
887
903
|
- `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`), `%function%` (or `%f%`).
|
|
888
904
|
|
|
905
|
+
### Weak scaling CLI example
|
|
906
|
+
|
|
907
|
+
For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
|
|
908
|
+
|
|
909
|
+
```bash
|
|
910
|
+
jax-hpc-profiler plot \
|
|
911
|
+
-f MYDATA.csv \
|
|
912
|
+
-pt mean_time \
|
|
913
|
+
-sc Weak \
|
|
914
|
+
-g 1 2 4 8 \
|
|
915
|
+
-d 32 64 128 256 \
|
|
916
|
+
--weak_ideal_line
|
|
917
|
+
```
|
|
918
|
+
|
|
919
|
+
This produces a single weak-scaling plot of runtime versus number of GPUs from the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)`, and overlays an ideal weak-scaling reference line.
|
|
920
|
+
|
|
889
921
|
## Examples
|
|
890
922
|
|
|
891
923
|
The repository includes examples for both profiling and plotting.
|
|
@@ -1,5 +1,11 @@
|
|
|
1
1
|
# JAX HPC Profiler
|
|
2
2
|
|
|
3
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
|
|
4
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
|
|
5
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
|
|
6
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
|
|
7
|
+
[](https://www.gnu.org/licenses/gpl-3.0)
|
|
8
|
+
|
|
3
9
|
JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionality to generate, concatenate, and plot CSV data from multiple runs.
|
|
4
10
|
|
|
5
11
|
## Table of Contents
|
|
@@ -181,9 +187,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
|
|
|
181
187
|
- `-db, --dark_bg`: Use dark background for plotting.
|
|
182
188
|
- `-pd, --print_decompositions`: Print decompositions on plot (experimental).
|
|
183
189
|
- `-b, --backends`: List of backends to include. This argument can be multiple ones.
|
|
184
|
-
- `-sc, --scaling`: Scaling type (`Weak`, `
|
|
190
|
+
- `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
|
|
191
|
+
- `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
|
|
192
|
+
- `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
|
|
193
|
+
- `WeakFixed`: size scaling at a fixed GPU count (what `Weak` did in earlier versions); plots runtime (or memory) versus data size, grouped by number of GPUs.
|
|
194
|
+
- `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal flat reference line at the runtime measured with the smallest number of GPUs on the first plotted weak-scaling curve.
|
|
185
195
|
- `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`), `%function%` (or `%f%`).
|
|
186
196
|
|
|
197
|
+
### Weak scaling CLI example
|
|
198
|
+
|
|
199
|
+
For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
|
|
200
|
+
|
|
201
|
+
```bash
|
|
202
|
+
jax-hpc-profiler plot \
|
|
203
|
+
-f MYDATA.csv \
|
|
204
|
+
-pt mean_time \
|
|
205
|
+
-sc Weak \
|
|
206
|
+
-g 1 2 4 8 \
|
|
207
|
+
-d 32 64 128 256 \
|
|
208
|
+
--weak_ideal_line
|
|
209
|
+
```
|
|
210
|
+
|
|
211
|
+
This produces a single weak-scaling plot of runtime versus number of GPUs from the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)`, and overlays an ideal weak-scaling reference line.
|
|
212
|
+
|
|
187
213
|
## Examples
|
|
188
214
|
|
|
189
215
|
The repository includes examples for both profiling and plotting.
|
|
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "jax_hpc_profiler"
|
|
7
|
-
version = "0.
|
|
8
|
-
description = "
|
|
7
|
+
version = "0.3.0"
|
|
8
|
+
description = "A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization."
|
|
9
9
|
authors = [
|
|
10
10
|
{ name="Wassim Kabalan" }
|
|
11
11
|
]
|
|
@@ -14,12 +14,15 @@ dependencies = [
|
|
|
14
14
|
"pandas",
|
|
15
15
|
"matplotlib",
|
|
16
16
|
"seaborn",
|
|
17
|
-
"tabulate"
|
|
17
|
+
"tabulate",
|
|
18
|
+
"adjustText",
|
|
19
|
+
"jax>=0.4.0",
|
|
20
|
+
"jaxtyping",
|
|
18
21
|
]
|
|
19
22
|
readme = "README.md"
|
|
20
23
|
license = { file = "LICENSE" }
|
|
21
24
|
requires-python = ">=3.8"
|
|
22
|
-
keywords = ["jax", "hpc", "
|
|
25
|
+
keywords = ["jax", "hpc", "profiling", "benchmarking", "visualization", "scaling", "performance-analysis", "gpu", "distributed-computing"]
|
|
23
26
|
|
|
24
27
|
# For a list of valid classifiers, see https://pypi.org/classifiers/
|
|
25
28
|
classifiers = [
|
|
@@ -45,5 +48,47 @@ classifiers = [
|
|
|
45
48
|
|
|
46
49
|
urls = { "Homepage" = "https://github.com/ASKabalan/jax-hpc-profiler" }
|
|
47
50
|
|
|
51
|
+
[project.optional-dependencies]
|
|
52
|
+
test = [
|
|
53
|
+
"pytest",
|
|
54
|
+
"pytest-cov",
|
|
55
|
+
]
|
|
56
|
+
|
|
48
57
|
[project.scripts]
|
|
49
58
|
jhp = "jax_hpc_profiler.main:main"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
[tool.ruff]
|
|
63
|
+
line-length = 100
|
|
64
|
+
fix = true # autofix issues
|
|
65
|
+
force-exclude = true # useful with ruff-pre-commit plugin
|
|
66
|
+
src = ["src"]
|
|
67
|
+
|
|
68
|
+
[tool.ruff.lint]
|
|
69
|
+
select = [
|
|
70
|
+
'ARG001', # flake8-unused-function-arguments
|
|
71
|
+
'E', # pycodestyle-errors
|
|
72
|
+
'F', # pyflakes
|
|
73
|
+
'I', # isort
|
|
74
|
+
'UP', # pyupgrade
|
|
75
|
+
'T10', # flake8-debugger
|
|
76
|
+
]
|
|
77
|
+
ignore = [
|
|
78
|
+
'E203',
|
|
79
|
+
'E731',
|
|
80
|
+
'E741',
|
|
81
|
+
'F722', # conflicts with jaxtyping Array annotations
|
|
82
|
+
"E402", # module level import not at top of file
|
|
83
|
+
]
|
|
84
|
+
|
|
85
|
+
[tool.ruff.lint.per-file-ignores]
|
|
86
|
+
"*.ipynb" = ["F401"]
|
|
87
|
+
|
|
88
|
+
[tool.ruff.format]
|
|
89
|
+
quote-style = 'single'
|
|
90
|
+
|
|
91
|
+
[tool.pytest.ini_options]
|
|
92
|
+
addopts = "--cov=src --cov-report=term-missing"
|
|
93
|
+
testpaths = ["tests"]
|
|
94
|
+
python_files = "test_*.py"
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .create_argparse import create_argparser
|
|
2
|
+
from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
|
|
3
|
+
from .timer import Timer
|
|
4
|
+
from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
|
|
5
|
+
|
|
6
|
+
# Public API of the package: the names re-exported by the imports above.
__all__ = [
    # Command-line argument parsing
    'create_argparser',
    # Plotting entry points, one per scaling mode (Strong, Weak, WeakFixed)
    'plot_strong_scaling',
    'plot_weak_scaling',
    'plot_weak_fixed_scaling',
    # Timing/instrumentation helper
    'Timer',
    # CSV and plotting utilities
    'clean_up_csv',
    'concatenate_csvs',
    'plot_with_pdims_strategy',
]
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def create_argparser(argv=None):
    """
    Create the argument parser for the HPC Plotter package and parse arguments.

    Three sub-commands are defined: ``concat`` (concatenate CSV files),
    ``plot`` (plot CSV data) and ``label_help`` (print label placeholder help).

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse, without the program name. When ``None`` (the
        default), ``sys.argv[1:]`` is used, exactly as ``parse_args()`` does.

    Returns
    -------
    argparse.Namespace
        Parsed and validated arguments. For the ``plot`` command the namespace
        also carries ``plot_columns`` (the selected time or memory columns).
        ``pdim_strategy`` is collapsed to a single entry when ``plot_all`` or
        ``plot_fastest`` is combined with other strategies.

    Raises
    ------
    SystemExit
        On invalid command-line arguments (raised by argparse).
    """
    parser = argparse.ArgumentParser(description='HPC Plotter for benchmarking data')

    # Each mode of operation is a sub-command; exactly one must be given.
    subparsers = parser.add_subparsers(dest='command', required=True)

    concat_parser = subparsers.add_parser('concat', help='Concatenate CSV files')
    concat_parser.add_argument('input', type=str, help='Input directory for concatenation')
    concat_parser.add_argument('output', type=str, help='Output directory for concatenation')

    # Arguments for plotting
    plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
    plot_parser.add_argument(
        '-f', '--csv_files', nargs='+', help='List of CSV files to plot', required=True
    )
    plot_parser.add_argument(
        '-g',
        '--gpus',
        nargs='*',
        type=int,
        help='List of number of GPUs to plot',
        default=None,
    )
    plot_parser.add_argument(
        '-d',
        '--data_size',
        nargs='*',
        type=int,
        help='List of data sizes to plot',
        default=None,
    )

    # pdims related arguments
    plot_parser.add_argument(
        '-fd',
        '--filter_pdims',
        nargs='*',
        help='List of pdims to filter, e.g., 1x4 2x2 4x8',
        default=None,
    )
    plot_parser.add_argument(
        '-ps',
        '--pdim_strategy',
        choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
        nargs='*',
        default=['plot_fastest'],
        help='Strategy for plotting pdims',
    )

    # Function and precision related arguments
    plot_parser.add_argument(
        '-pr',
        '--precision',
        choices=['float32', 'float64'],
        default=['float32', 'float64'],
        nargs='*',
        help='Precision to filter by (float32 or float64)',
    )
    plot_parser.add_argument(
        '-fn',
        '--function_name',
        nargs='+',
        help='Function names to filter',
        default=None,
    )

    # Time or memory related arguments: exactly one of -pt / -pm is required.
    plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
    plotting_group.add_argument(
        '-pt',
        '--plot_times',
        nargs='*',
        choices=[
            'jit_time',
            'min_time',
            'max_time',
            'mean_time',
            'std_time',
            'last_time',
        ],
        help='Time columns to plot',
    )
    plotting_group.add_argument(
        '-pm',
        '--plot_memory',
        nargs='*',
        choices=['generated_code', 'argument_size', 'output_size', 'temp_size'],
        help='Memory columns to plot',
    )
    plot_parser.add_argument(
        '-mu',
        '--memory_units',
        default='GB',
        help='Memory units to plot (KB, MB, GB, TB)',
    )

    # Plot customization arguments
    plot_parser.add_argument(
        '-fs', '--figure_size', nargs=2, type=int, help='Figure size', default=(10, 6)
    )
    plot_parser.add_argument(
        '-o', '--output', help='Output file (if none then only show plot)', default=None
    )
    plot_parser.add_argument(
        '-db', '--dark_bg', action='store_true', help='Use dark background for plotting'
    )
    plot_parser.add_argument(
        '-pd',
        '--print_decompositions',
        action='store_true',
        help='Print decompositions on plot',
    )

    # Backend related arguments
    plot_parser.add_argument(
        '-b',
        '--backends',
        nargs='*',
        default=['MPI', 'NCCL', 'MPI4JAX'],
        help='List of backends to include',
    )

    # Scaling type argument; long names and short aliases are both accepted.
    plot_parser.add_argument(
        '-sc',
        '--scaling',
        choices=['Weak', 'Strong', 'WeakFixed', 'w', 's', 'wf'],
        required=True,
        help='Scaling type (Strong, Weak, or WeakFixed)',
    )

    # Weak-scaling specific options
    plot_parser.add_argument(
        '--weak_ideal_line',
        action='store_true',
        help='Overlay an ideal flat line for weak scaling (Weak mode only)',
    )
    plot_parser.add_argument(
        '--weak_reverse_axes',
        action='store_true',
        help=(
            'Weak mode only: put data size on the x-axis and annotate each point with GPUs instead '
            'of data size. Requires --gpus and --data_size with equal lengths.'
        ),
    )

    # Label customization argument ('%%' is argparse's escape for a literal '%').
    plot_parser.add_argument(
        '-l',
        '--label_text',
        type=str,
        help=(
            'Custom label for the plot. You can use placeholders: %%decomposition%% '
            '(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), '
            '%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%), '
            '%%function%% (or %%f%%)'
        ),
        default='%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
    )

    plot_parser.add_argument(
        '-xl',
        '--xlabel',
        type=str,
        help='X-axis label for the plot',
    )
    plot_parser.add_argument(
        '-tl',
        '--title',
        type=str,
        help='Title for the plot',
    )

    subparsers.add_parser('label_help', help='Label customization help')

    args = parser.parse_args(argv)

    # For the plot command, check that the pdim_strategy combination is valid.
    if args.command == 'plot':
        # 'plot_all' and 'plot_fastest' are exclusive strategies: when either is
        # combined with others, keep only that one. 'plot_all' is checked first,
        # so it takes precedence if both are present.
        for exclusive in ('plot_all', 'plot_fastest'):
            if exclusive in args.pdim_strategy and len(args.pdim_strategy) > 1:
                print(
                    f"Warning: '{exclusive}' strategy is combined with other strategies. "
                    f"Using '{exclusive}' only."
                )
                args.pdim_strategy = [exclusive]

        if args.plot_times is not None:
            args.plot_columns = args.plot_times
        elif args.plot_memory is not None:
            args.plot_columns = args.plot_memory
        else:
            # Defensive only: the plotting group is required, so argparse
            # already rejects a command line with neither option.
            raise ValueError('Either plot_times or plot_memory should be provided')

        # Note: for Weak scaling, plot_weak_scaling enforces that both gpus and
        # data_size are provided and have matching lengths. For Strong and
        # WeakFixed, gpus/data_size remain optional as before.

    return args
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from .create_argparse import create_argparser
|
|
2
|
+
from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
|
|
3
|
+
from .utils import concatenate_csvs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def main():
    """
    Run the ``jhp`` command-line tool.

    Parses the command line with ``create_argparser`` and dispatches on the
    selected sub-command:

    * ``concat``: ``concatenate_csvs(input, output)``.
    * ``label_help``: print the supported label placeholders.
    * ``plot``: one of ``plot_weak_scaling``, ``plot_strong_scaling`` or
      ``plot_weak_fixed_scaling``, chosen by ``--scaling`` (compared
      case-insensitively; short forms ``w``, ``s`` and ``wf`` are accepted).
    """
    args = create_argparser()
    command = args.command

    if command == 'concat':
        concatenate_csvs(args.input, args.output)
        return

    if command == 'label_help':
        help_lines = (
            'Customize the label text for the plot. using these commands.',
            ' -- %m% or %methodname%: method name',
            ' -- %f% or %function%: function name',
            ' -- %pn% or %plot_name%: plot name',
            ' -- %pr% or %precision%: precision',
            ' -- %b% or %backend%: backend',
            ' -- %p% or %pdims%: pdims',
            ' -- %n% or %node%: node',
        )
        for line in help_lines:
            print(line)
        return

    if command != 'plot':
        return

    # Leading positional arguments shared, in this order, by all plotting functions.
    shared = (
        args.csv_files,
        args.gpus,
        args.data_size,
        args.function_name,
        args.precision,
        args.filter_pdims,
        args.pdim_strategy,
        args.print_decompositions,
        args.backends,
        args.plot_columns,
        args.memory_units,
        args.label_text,
    )
    # Figure options that follow the axis label and title in every call.
    figure_opts = (args.figure_size, args.dark_bg, args.output)

    def _or_default(attr, fallback):
        # Use the user-supplied value unless it is missing or None.
        value = getattr(args, attr, None)
        return fallback if value is None else value

    mode = args.scaling.lower()
    if mode in ('weak', 'w'):
        plot_weak_scaling(
            *shared,
            _or_default('xlabel', 'Number of GPUs'),
            args.title,
            *figure_opts,
            args.weak_ideal_line,
            args.weak_reverse_axes,
        )
    elif mode in ('strong', 's'):
        plot_strong_scaling(
            *shared,
            _or_default('xlabel', 'Number of GPUs'),
            _or_default('title', 'Data sizes'),
            *figure_opts,
        )
    elif mode in ('weakfixed', 'wf'):
        plot_weak_fixed_scaling(
            *shared,
            _or_default('xlabel', 'Data sizes'),
            _or_default('title', 'Number of GPUs'),
            *figure_opts,
        )


if __name__ == '__main__':
    main()
|