jax-hpc-profiler 0.2.13__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/PKG-INFO +36 -4
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/README.md +27 -1
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/pyproject.toml +22 -4
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/__init__.py +2 -1
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/create_argparse.py +21 -2
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/main.py +28 -5
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/plotting.py +191 -1
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/timer.py +12 -8
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/PKG-INFO +36 -4
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/SOURCES.txt +3 -1
- jax_hpc_profiler-0.3.0/src/jax_hpc_profiler.egg-info/requires.txt +12 -0
- jax_hpc_profiler-0.3.0/tests/test_plotting.py +112 -0
- jax_hpc_profiler-0.3.0/tests/test_timer.py +103 -0
- jax_hpc_profiler-0.2.13/src/jax_hpc_profiler.egg-info/requires.txt +0 -5
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/utils.py +0 -0
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0

{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/PKG-INFO RENAMED

```diff
@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: jax_hpc_profiler
-Version: 0.2.13
-Summary:
+Version: 0.3.0
+Summary: A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization.
 Author: Wassim Kabalan
 License: GNU GENERAL PUBLIC LICENSE
         Version 3, 29 June 2007
@@ -679,7 +679,7 @@ License: GNU GENERAL PUBLIC LICENSE
 <https://www.gnu.org/licenses/why-not-lgpl.html>.
 
 Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
-Keywords: jax,hpc,
+Keywords: jax,hpc,profiling,benchmarking,visualization,scaling,performance-analysis,gpu,distributed-computing
 Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
@@ -698,10 +698,22 @@ Requires-Dist: pandas
 Requires-Dist: matplotlib
 Requires-Dist: seaborn
 Requires-Dist: tabulate
+Requires-Dist: adjustText
+Requires-Dist: jax>=0.4.0
+Requires-Dist: jaxtyping
+Provides-Extra: test
+Requires-Dist: pytest; extra == "test"
+Requires-Dist: pytest-cov; extra == "test"
 Dynamic: license-file
 
 # JAX HPC Profiler
 
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
+[](https://www.gnu.org/licenses/gpl-3.0)
+
 JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
 
 ## Table of Contents
@@ -883,9 +895,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
 - `-db, --dark_bg`: Use dark background for plotting.
 - `-pd, --print_decompositions`: Print decompositions on plot (experimental).
 - `-b, --backends`: List of backends to include. This argument can be multiple ones.
-- `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
+- `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
+  - `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
+  - `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
+  - `WeakFixed`: size scaling at fixed GPU count (previous weak behavior); plots runtime (or memory) versus data size, grouped by number of GPUs.
+- `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal flat line based on the smallest-GPU runtime for the first plotted weak-scaling curve.
 - `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
 
+### Weak scaling CLI example
+
+For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
+
+```bash
+jax-hpc-profiler plot \
+    -f MYDATA.csv \
+    -pt mean_time \
+    -sc Weak \
+    -g 1 2 4 8 \
+    -d 32 64 128 256 \
+    --weak_ideal_line
+```
+
+This will produce a single weak-scaling plot of runtime versus number of GPUs, using the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)` and overlay an ideal weak-scaling reference line.
+
 ## Examples
 
 The repository includes examples for both profiling and plotting.
```
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/README.md RENAMED

```diff
@@ -1,5 +1,11 @@
 # JAX HPC Profiler
 
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
+[](https://www.gnu.org/licenses/gpl-3.0)
+
 JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
 
 ## Table of Contents
@@ -181,9 +187,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
 - `-db, --dark_bg`: Use dark background for plotting.
 - `-pd, --print_decompositions`: Print decompositions on plot (experimental).
 - `-b, --backends`: List of backends to include. This argument can be multiple ones.
-- `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
+- `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
+  - `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
+  - `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
+  - `WeakFixed`: size scaling at fixed GPU count (previous weak behavior); plots runtime (or memory) versus data size, grouped by number of GPUs.
+- `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal flat line based on the smallest-GPU runtime for the first plotted weak-scaling curve.
 - `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
 
+### Weak scaling CLI example
+
+For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
+
+```bash
+jax-hpc-profiler plot \
+    -f MYDATA.csv \
+    -pt mean_time \
+    -sc Weak \
+    -g 1 2 4 8 \
+    -d 32 64 128 256 \
+    --weak_ideal_line
+```
+
+This will produce a single weak-scaling plot of runtime versus number of GPUs, using the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)` and overlay an ideal weak-scaling reference line.
+
 ## Examples
 
 The repository includes examples for both profiling and plotting.
```
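The same weak-scaling figure can also be produced from Python with the `plot_weak_scaling` function added in this release. A minimal sketch, assuming `MYDATA.csv` was written by the profiler (the output filename is a placeholder, and the remaining arguments keep their defaults):

```python
# Sketch only: Python equivalent of the CLI example in the README diff above.
from jax_hpc_profiler import plot_weak_scaling

plot_weak_scaling(
    ['MYDATA.csv'],                      # -f MYDATA.csv
    fixed_gpu_size=[1, 2, 4, 8],         # -g 1 2 4 8
    fixed_data_size=[32, 64, 128, 256],  # -d 32 64 128 256
    plot_columns=['mean_time'],          # -pt mean_time (also the default)
    ideal_line=True,                     # --weak_ideal_line
    output='weak_scaling.png',           # placeholder path; omit to show the figure
)
```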
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/pyproject.toml RENAMED

```diff
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "jax_hpc_profiler"
-version = "0.2.13"
-description = "
+version = "0.3.0"
+description = "A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization."
 authors = [
     { name="Wassim Kabalan" }
 ]
@@ -14,12 +14,15 @@ dependencies = [
     "pandas",
     "matplotlib",
     "seaborn",
-    "tabulate"
+    "tabulate",
+    "adjustText",
+    "jax>=0.4.0",
+    "jaxtyping",
 ]
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.8"
-keywords = ["jax", "hpc", "
+keywords = ["jax", "hpc", "profiling", "benchmarking", "visualization", "scaling", "performance-analysis", "gpu", "distributed-computing"]
 
 # For a list of valid classifiers, see https://pypi.org/classifiers/
 classifiers = [
@@ -45,6 +48,12 @@ classifiers = [
 
 urls = { "Homepage" = "https://github.com/ASKabalan/jax-hpc-profiler" }
 
+[project.optional-dependencies]
+test = [
+    "pytest",
+    "pytest-cov",
+]
+
 [project.scripts]
 jhp = "jax_hpc_profiler.main:main"
 
@@ -70,7 +79,16 @@ ignore = [
     'E731',
     'E741',
     'F722', # conflicts with jaxtyping Array annotations
+    "E402", # module level import not at top of file
 ]
 
+[tool.ruff.lint.per-file-ignores]
+"*.ipynb" = ["F401"]
+
 [tool.ruff.format]
 quote-style = 'single'
+
+[tool.pytest.ini_options]
+addopts = "--cov=src --cov-report=term-missing"
+testpaths = ["tests"]
+python_files = "test_*.py"
```
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/__init__.py RENAMED

```diff
@@ -1,5 +1,5 @@
 from .create_argparse import create_argparser
-from .plotting import plot_strong_scaling, plot_weak_scaling
+from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
 from .timer import Timer
 from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
 
@@ -7,6 +7,7 @@ __all__ = [
     'create_argparser',
     'plot_strong_scaling',
     'plot_weak_scaling',
+    'plot_weak_fixed_scaling',
     'Timer',
     'clean_up_csv',
     'concatenate_csvs',
```
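After this change the package's public surface exposes the renamed function alongside the existing ones; a quick sanity check of the exports (names taken from `__all__` above):

```python
# Public API after 0.3.0; names as listed in __all__ in the diff above.
from jax_hpc_profiler import (
    Timer,
    clean_up_csv,
    concatenate_csvs,
    create_argparser,
    plot_strong_scaling,
    plot_weak_fixed_scaling,  # new: size scaling at a fixed GPU count
    plot_weak_scaling,        # now true weak scaling over (gpus, data_size) pairs
)
```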
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/create_argparse.py RENAMED

```diff
@@ -135,9 +135,24 @@ def create_argparser():
     plot_parser.add_argument(
         '-sc',
         '--scaling',
-        choices=['Weak', 'Strong', 'w', 's'],
+        choices=['Weak', 'Strong', 'WeakFixed', 'w', 's', 'wf'],
         required=True,
-        help='Scaling type (Weak or
+        help='Scaling type (Strong, Weak, or WeakFixed)',
+    )
+
+    # Weak-scaling specific options
+    plot_parser.add_argument(
+        '--weak_ideal_line',
+        action='store_true',
+        help='Overlay an ideal flat line for weak scaling (Weak mode only)',
+    )
+    plot_parser.add_argument(
+        '--weak_reverse_axes',
+        action='store_true',
+        help=(
+            'Weak mode only: put data size on the x-axis and annotate each point with GPUs instead '
+            'of data size. Requires --gpus and --data_size with equal lengths.'
+        ),
     )
 
     # Label customization argument
@@ -196,4 +211,8 @@ def create_argparser():
     else:
         raise ValueError('Either plot_times or plot_memory should be provided')
 
+    # Note: for Weak scaling, plot_weak_scaling enforces that both gpus and
+    # data_size are provided and have matching lengths. For Strong and
+    # WeakFixed, gpus/data_size remain optional as before.
+
     return args
```
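Both new options are plain `store_true` flags on the `plot` sub-command, so they default to `False` and flip to `True` when present. A standalone sketch of the same argparse pattern (not the library's actual parser, just an illustration of how the new flags parse):

```python
# Standalone illustration of the flag pattern added above; NOT the library's parser.
import argparse

parser = argparse.ArgumentParser(prog='jhp plot')
parser.add_argument('-sc', '--scaling', required=True,
                    choices=['Weak', 'Strong', 'WeakFixed', 'w', 's', 'wf'])
parser.add_argument('--weak_ideal_line', action='store_true')
parser.add_argument('--weak_reverse_axes', action='store_true')

args = parser.parse_args(['-sc', 'Weak', '--weak_ideal_line'])
print(args.scaling, args.weak_ideal_line, args.weak_reverse_axes)
# prints: Weak True False
```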
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/main.py RENAMED

```diff
@@ -1,5 +1,5 @@
 from .create_argparse import create_argparser
-from .plotting import plot_strong_scaling, plot_weak_scaling
+from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
 from .utils import concatenate_csvs
 
 
@@ -19,7 +19,8 @@ def main():
         print(' -- %p% or %pdims%: pdims')
         print(' -- %n% or %node%: node')
     elif args.command == 'plot':
-
+        scaling = args.scaling.lower()
+        if scaling in ('weak', 'w'):
             plot_weak_scaling(
                 args.csv_files,
                 args.gpus,
@@ -33,13 +34,15 @@ def main():
                 args.plot_columns,
                 args.memory_units,
                 args.label_text,
+                args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
                 args.title,
-                args.label_text,
                 args.figure_size,
                 args.dark_bg,
                 args.output,
+                args.weak_ideal_line,
+                args.weak_reverse_axes,
             )
-        elif
+        elif scaling in ('strong', 's'):
             plot_strong_scaling(
                 args.csv_files,
                 args.gpus,
@@ -53,8 +56,28 @@ def main():
                 args.plot_columns,
                 args.memory_units,
                 args.label_text,
-                args.
+                args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
+                args.title if getattr(args, 'title', None) is not None else 'Data sizes',
+                args.figure_size,
+                args.dark_bg,
+                args.output,
+            )
+        elif scaling in ('weakfixed', 'wf'):
+            plot_weak_fixed_scaling(
+                args.csv_files,
+                args.gpus,
+                args.data_size,
+                args.function_name,
+                args.precision,
+                args.filter_pdims,
+                args.pdim_strategy,
+                args.print_decompositions,
+                args.backends,
+                args.plot_columns,
+                args.memory_units,
                 args.label_text,
+                args.xlabel if getattr(args, 'xlabel', None) is not None else 'Data sizes',
+                args.title if getattr(args, 'title', None) is not None else 'Number of GPUs',
                 args.figure_size,
                 args.dark_bg,
                 args.output,
```
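The dispatch in `main()` now lower-cases the scaling choice once, so `Weak`/`w`, `Strong`/`s` and `WeakFixed`/`wf` all reach the matching plotting function. A condensed sketch of that mapping (not the library's code, which forwards the full argument lists shown above):

```python
# Condensed sketch of the scaling dispatch added to main().
from jax_hpc_profiler import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling

SCALING_DISPATCH = {
    'weak': plot_weak_scaling,             # 'Weak' / 'w'
    'w': plot_weak_scaling,
    'strong': plot_strong_scaling,         # 'Strong' / 's'
    's': plot_strong_scaling,
    'weakfixed': plot_weak_fixed_scaling,  # 'WeakFixed' / 'wf'
    'wf': plot_weak_fixed_scaling,
}

def pick_plotter(scaling_choice: str):
    # main() lower-cases args.scaling before branching, so both the long and
    # short spellings resolve to the same plotting function.
    return SCALING_DISPATCH[scaling_choice.lower()]
```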
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/plotting.py RENAMED

```diff
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from adjustText import adjust_text
 from matplotlib.axes import Axes
 from matplotlib.patches import FancyBboxPatch
 
@@ -252,6 +253,195 @@ def plot_strong_scaling(
 
 
 def plot_weak_scaling(
+    csv_files: List[str],
+    fixed_gpu_size: Optional[List[int]] = None,
+    fixed_data_size: Optional[List[int]] = None,
+    functions: Optional[List[str]] = None,
+    precisions: Optional[List[str]] = None,
+    pdims: Optional[List[str]] = None,
+    pdims_strategy: List[str] = ['plot_fastest'],
+    print_decompositions: bool = False,
+    backends: Optional[List[str]] = None,
+    plot_columns: List[str] = ['mean_time'],
+    memory_units: str = 'bytes',
+    label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
+    xlabel: str = 'Number of GPUs',
+    title: Optional[str] = None,
+    figure_size: tuple = (6, 4),
+    dark_bg: bool = False,
+    output: Optional[str] = None,
+    ideal_line: bool = False,
+    reverse_axes: bool = False,
+):
+    """
+    Plot true weak scaling: runtime vs GPUs for explicit (gpus, data size) sequences.
+
+    Both ``fixed_gpu_size`` and ``fixed_data_size`` must be provided and have the same length,
+    representing explicit weak-scaling pairs (gpus[i], data_size[i]).
+
+    reverse_axes:
+      - False (default): x-axis is GPUs, y-axis is time; points are annotated with
+        ``N=<data_size>``.
+      - True: x-axis is data size, y-axis is time; points are annotated with ``GPUs=<gpu_count>``.
+    """
+    if fixed_gpu_size is None or fixed_data_size is None:
+        raise ValueError(
+            'Weak scaling requires both fixed_gpu_size (gpus) and fixed_data_size (problem sizes).'
+        )
+    if len(fixed_gpu_size) != len(fixed_data_size):
+        raise ValueError(
+            'Weak scaling requires fixed_gpu_size and fixed_data_size lists of equal length.'
+        )
+
+    gpu_to_data = {int(g): int(d) for g, d in zip(fixed_gpu_size, fixed_data_size)}
+    data_to_gpu = {int(d): int(g) for g, d in zip(fixed_gpu_size, fixed_data_size)}
+    x_col = 'x' if reverse_axes else 'gpus'
+
+    dataframes, _, _ = clean_up_csv(
+        csv_files,
+        precisions,
+        functions,
+        fixed_gpu_size,
+        fixed_data_size,
+        pdims,
+        pdims_strategy,
+        backends,
+        memory_units,
+    )
+    if len(dataframes) == 0:
+        print('No dataframes found for the given arguments. Exiting...')
+        return
+
+    if dark_bg:
+        plt.style.use('dark_background')
+
+    fig, ax = plt.subplots(figsize=figure_size)
+
+    x_values: List[float] = []
+    y_values: List[float] = []
+    annotations: List = []
+    ideal_line_plotted = False
+
+    for method, df in dataframes.items():
+        # Determine parameter sets from the filtered dataframe if not provided
+        local_functions = pd.unique(df['function']) if functions is None else functions
+        local_precisions = pd.unique(df['precision']) if precisions is None else precisions
+        local_backends = pd.unique(df['backend']) if backends is None else backends
+
+        combinations = product(local_backends, local_precisions, local_functions, plot_columns)
+
+        for backend, precision, function, plot_column in combinations:
+            base_df = df[
+                (df['backend'] == backend)
+                & (df['precision'] == precision)
+                & (df['function'] == function)
+            ]
+            if base_df.empty:
+                continue
+
+            # Keep only rows matching any of the (gpus, x) pairs
+            mask = pd.Series(False, index=base_df.index)
+            for g, d in zip(fixed_gpu_size, fixed_data_size):
+                mask |= (base_df['gpus'] == int(g)) & (base_df['x'] == int(d))
+
+            filtered_params_df = base_df[mask]
+            if filtered_params_df.empty:
+                continue
+
+            x_vals, y_vals = plot_with_pdims_strategy(
+                ax,
+                filtered_params_df,
+                method,
+                pdims_strategy,
+                print_decompositions,
+                x_col,
+                plot_column,
+                label_text,
+            )
+            if x_vals is None or len(x_vals) == 0:
+                continue
+
+            x_arr = np.asarray(x_vals).reshape(-1)
+            y_arr = np.asarray(y_vals).reshape(-1)
+
+            # Annotate every point with data size or GPU count depending on axis choice.
+            # Use plain data coordinates for the text; adjust_text will then only move
+            # the labels slightly (mostly vertically) to avoid overlap.
+            for xv, yv in zip(x_arr, y_arr):
+                if reverse_axes:
+                    gpu = data_to_gpu.get(int(xv))
+                    if gpu is None:
+                        continue
+                    label = f'GPUs={gpu}'
+                else:
+                    data_size = gpu_to_data.get(int(xv))
+                    if data_size is None:
+                        continue
+                    label = f'N={data_size}'
+
+                text_obj = ax.text(
+                    float(xv),
+                    float(yv),
+                    label,
+                    ha='center',
+                    va='bottom',
+                    fontsize='small',
+                    clip_on=True,
+                )
+                annotations.append(text_obj)
+
+            x_values.extend(x_arr.tolist())
+            y_values.extend(y_arr.tolist())
+
+            if ideal_line and not ideal_line_plotted:
+                # Use the smallest x value in this curve as baseline
+                baseline_index = np.argmin(x_arr)
+                baseline_y = y_arr[baseline_index]
+                ax.hlines(
+                    baseline_y,
+                    xmin=float(np.min(x_arr)),
+                    xmax=float(np.max(x_arr)),
+                    colors='gray',
+                    linestyles='dashed',
+                    label='Ideal weak scaling',
+                )
+                ideal_line_plotted = True
+                y_values.append(float(baseline_y))
+
+    if x_values:
+        plotting_memory = 'time' not in plot_columns[0].lower()
+        figure_title = title if title is not None else 'Weak scaling'
+        configure_axes(
+            ax,
+            x_values,
+            y_values,
+            figure_title,
+            xlabel,
+            plotting_memory,
+            memory_units,
+        )
+    if annotations:
+        ax.figure.canvas.draw()
+        adjust_text(
+            annotations,
+            ax=ax,
+            # keep points aligned in x, only allow vertical motion
+            only_move={'text': 'y', 'static': 'y'},
+            expand=(1.02, 1.05),
+            force_text=(0.08, 0.2),
+            max_move=(0, 30),
+        )
+
+    fig.tight_layout()
+    rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
+    fig.patches.append(rect)
+    if output is None:
+        plt.show()
+    else:
+        plt.savefig(output)
+
+
+def plot_weak_fixed_scaling(
     csv_files: List[str],
     fixed_gpu_size: Optional[List[int]] = None,
     fixed_data_size: Optional[List[int]] = None,
@@ -271,7 +461,7 @@ def plot_weak_scaling(
     output: Optional[str] = None,
 ):
     """
-    Plot
+    Plot size scaling at fixed GPU count (previous weak-scaling behavior).
     """
     dataframes, available_gpu_counts, _ = clean_up_csv(
         csv_files,
```
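Called from Python, the renamed `plot_weak_fixed_scaling` keeps the pre-0.3.0 weak-scaling behavior (curves over several data sizes at a fixed GPU count), while the new `plot_weak_scaling` walks explicit `(gpus, data_size)` pairs and can swap its axes. A minimal sketch along the lines of the calls exercised in `tests/test_plotting.py`; `results.csv` and the output paths are placeholders:

```python
# Minimal sketch, assuming results.csv was written by Timer.report;
# file names here are placeholders, not part of the package.
from jax_hpc_profiler import plot_weak_fixed_scaling, plot_weak_scaling

# WeakFixed: runtime versus data size at a fixed GPU count.
plot_weak_fixed_scaling(
    ['results.csv'],
    fixed_gpu_size=[1],
    fixed_data_size=[100, 200, 400],
    xlabel='Data sizes',
    output='weak_fixed.png',
)

# Weak with reverse_axes=True: data size on the x-axis, each point annotated
# with its GPU count instead of its data size.
plot_weak_scaling(
    ['results.csv'],
    fixed_gpu_size=[1, 2, 4, 8],
    fixed_data_size=[32, 64, 128, 256],
    reverse_axes=True,
    output='weak_reversed.png',
)
```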
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/timer.py RENAMED

```diff
@@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Tuple
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax import make_jaxpr
-from jax.experimental.shard_map import shard_map
+from jax import make_jaxpr, shard_map
 from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P
 from jaxtyping import Array
@@ -25,8 +24,13 @@ class Timer:
     ):
         self.jit_time = 0.0
         self.times = []
-        self.profiling_data = {
-
+        self.profiling_data = {
+            'generated_code': 'N/A',
+            'argument_size': 'N/A',
+            'output_size': 'N/A',
+            'temp_size': 'N/A',
+        }
+        self.compiled_code = {'JAXPR': 'N/A', 'LOWERED': 'N/A', 'COMPILED': 'N/A'}
         self.save_jaxpr = save_jaxpr
         self.compile_info = compile_info
         self.jax_fn = jax_fn
@@ -181,10 +185,10 @@ class Timer:
         mean_time = np.mean(times_array)
         std_time = np.std(times_array)
         last_time = times_array[-1]
-        generated_code = self.profiling_data
-        argument_size = self.profiling_data
-        output_size = self.profiling_data
-        temp_size = self.profiling_data
+        generated_code = self.profiling_data.get('generated_code', 'N/A')
+        argument_size = self.profiling_data.get('argument_size', 'N/A')
+        output_size = self.profiling_data.get('output_size', 'N/A')
+        temp_size = self.profiling_data.get('temp_size', 'N/A')
 
         csv_line = (
             f'{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},'
```
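With `profiling_data` and `compiled_code` now initialized to `'N/A'` placeholders and `report()` reading them through `.get(..., 'N/A')`, the reporting path works even when compile information was not collected. A minimal sketch of the workflow exercised in `tests/test_timer.py`; `results.csv` and `results.md` are placeholder paths:

```python
# Minimal Timer workflow, mirroring tests/test_timer.py.
import jax
import jax.numpy as jnp

from jax_hpc_profiler import Timer


@jax.jit
def simple_add(x, y):
    return x + y


timer = Timer()
x = jnp.ones((10, 10))
y = jnp.ones((10, 10))

timer.chrono_jit(simple_add, x, y)      # first call: records JIT/compile time
for _ in range(5):
    timer.chrono_fun(simple_add, x, y)  # steady-state timings appended to timer.times

# Appends one CSV line (function, precision, sizes, decomposition, timing stats,
# memory fields) and writes a markdown summary; uncollected fields fall back to 'N/A'.
timer.report(
    'results.csv',
    function='simple_add',
    x=10,
    y=10,
    precision='float32',
    md_filename='results.md',
)
```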
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/PKG-INFO RENAMED

Same changes as the top-level PKG-INFO diff above (the egg-info metadata mirrors it): version bumped to 0.3.0, the summary and keywords filled in, the adjustText, jax>=0.4.0 and jaxtyping runtime dependencies and the pytest/pytest-cov test extra added, and the README badges plus the new scaling-mode documentation and weak-scaling CLI example appended.
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/SOURCES.txt RENAMED

```diff
@@ -12,4 +12,6 @@ src/jax_hpc_profiler.egg-info/SOURCES.txt
 src/jax_hpc_profiler.egg-info/dependency_links.txt
 src/jax_hpc_profiler.egg-info/entry_points.txt
 src/jax_hpc_profiler.egg-info/requires.txt
-src/jax_hpc_profiler.egg-info/top_level.txt
+src/jax_hpc_profiler.egg-info/top_level.txt
+tests/test_plotting.py
+tests/test_timer.py
```
jax_hpc_profiler-0.3.0/tests/test_plotting.py ADDED

```diff
@@ -0,0 +1,112 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from jax_hpc_profiler.plotting import (
+    configure_axes,
+    plot_strong_scaling,
+    plot_weak_fixed_scaling,
+    plot_weak_scaling,
+)
+
+
+@pytest.fixture
+def mock_plt():
+    with patch('jax_hpc_profiler.plotting.plt') as mock:
+        mock_fig = MagicMock()
+        mock_ax = MagicMock()
+        mock.subplots.return_value = (mock_fig, mock_ax)
+        yield mock
+
+
+@pytest.fixture
+def sample_csv(tmp_path):
+    csv_file = tmp_path / 'test_data.csv'
+    # Create data matching Timer.report format (19 columns)
+    # function,precision,x,y,z,px,py,backend,nodes,jit,min,max,mean,std,last,gen_code,arg_size,
+    # out_size,tmp_size
+    data = [
+        'fun1,float32,100,100,100,1,1,NCCL,1,0.1,1.0,1.2,1.1,0.01,1.1,1000,1000,1000,1000',
+        'fun1,float32,200,200,200,1,1,NCCL,1,0.2,2.0,2.4,2.2,0.02,2.2,2000,2000,2000,2000',
+        'fun1,float32,400,400,400,1,1,NCCL,1,0.4,4.0,4.8,4.4,0.04,4.4,4000,4000,4000,4000',
+        # Add entries for strong scaling (same x, diff nodes/gpus)
+        # nodes is used as 'gpus' in some logic?
+        # utils.py: df['gpus'] = df['px'] * df['py']
+        # So we vary px*py.
+        'fun2,float32,1000,1000,1000,1,1,NCCL,1,0.1,10.0,12.0,11.0,0.1,11.0,1000,1000,1000,1000',
+        'fun2,float32,1000,1000,1000,2,1,NCCL,2,0.1,5.0,6.0,5.5,0.05,5.5,1000,1000,1000,1000',
+        'fun2,float32,1000,1000,1000,2,2,NCCL,4,0.1,2.5,3.0,2.75,0.025,2.75,1000,1000,1000,1000',
+    ]
+    with open(csv_file, 'w') as f:
+        f.write('\n'.join(data) + '\n')
+    return str(csv_file)
+
+
+@pytest.fixture
+def mock_adjust_text():
+    with patch('jax_hpc_profiler.plotting.adjust_text') as mock:
+        yield mock
+
+
+def test_plot_weak_fixed_scaling(mock_plt, sample_csv):
+    # WeakFixed: vary data size (x), fixed GPUs (calculated from px*py).
+    # In sample_csv fun1: gpus=1, x=[100, 200, 400]
+
+    plot_weak_fixed_scaling(
+        csv_files=[sample_csv],
+        fixed_gpu_size=[1],
+        fixed_data_size=[100, 200, 400],
+        functions=['fun1'],
+        xlabel='Data Size',
+        title='Weak Fixed Scaling',
+    )
+
+    assert mock_plt.show.called or mock_plt.savefig.called
+
+
+def test_plot_strong_scaling(mock_plt, sample_csv):
+    # Strong: fixed data size (x), vary GPUs.
+    # In sample_csv fun2: x=1000, gpus=[1, 2, 4]
+
+    plot_strong_scaling(
+        csv_files=[sample_csv],
+        fixed_data_size=[1000],
+        fixed_gpu_size=[1, 2, 4],
+        functions=['fun2'],
+        xlabel='GPUs',
+        title='Strong Scaling',
+    )
+
+    assert mock_plt.show.called or mock_plt.savefig.called
+
+
+def test_plot_weak_scaling(mock_plt, mock_adjust_text, sample_csv):
+    # Weak: explicit pairs of (gpus, data_size)
+    # We have (1, 100) for fun1.
+
+    plot_weak_scaling(
+        csv_files=[sample_csv],
+        fixed_gpu_size=[1],
+        fixed_data_size=[100],
+        functions=['fun1'],
+        xlabel='GPUs',
+        title='Weak Scaling',
+    )
+
+    assert mock_plt.show.called or mock_plt.savefig.called
+    assert (
+        mock_adjust_text.called or not mock_adjust_text.called
+    )  # It's okay if called or not, just don't crash.
+
+
+def test_configure_axes():
+    # Test the helper directly
+    mock_ax = MagicMock()
+    configure_axes(
+        mock_ax, x_values=[1, 2, 4], y_values=[10, 5, 2.5], title='Test Plot', xlabel='X Label'
+    )
+
+    mock_ax.set_title.assert_called_with('Test Plot')
+    mock_ax.set_xlabel.assert_called_with('X Label')
+    mock_ax.set_xscale.assert_called()
+    mock_ax.set_yscale.assert_called_with('symlog')
```
jax_hpc_profiler-0.3.0/tests/test_timer.py ADDED

```diff
@@ -0,0 +1,103 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from jax_hpc_profiler import Timer
+
+
+# Simple JAX function
+@jax.jit
+def simple_add(x, y):
+    return x + y
+
+
+def test_timer_initialization():
+    timer = Timer()
+    assert timer.save_jaxpr is False
+    assert timer.compile_info is True
+    assert timer.jit_time == 0.0
+    assert len(timer.times) == 0
+
+
+def test_chrono_jit():
+    timer = Timer(save_jaxpr=True, compile_info=True)
+    x = jnp.ones((10, 10))
+    y = jnp.ones((10, 10))
+
+    out = timer.chrono_jit(simple_add, x, y)
+    assert out.shape == (10, 10)
+    assert timer.jit_time > 0
+    assert timer.compiled_code['JAXPR'] != 'N/A'
+    assert timer.compiled_code['COMPILED'] != 'N/A'
+    assert timer.profiling_data['generated_code'] != 'N/A'
+
+
+def test_chrono_fun():
+    timer = Timer()
+    x = jnp.ones((10, 10))
+    y = jnp.ones((10, 10))
+
+    # Run once to compile (if using jit externally, but here we just call the function)
+    out = timer.chrono_fun(simple_add, x, y)
+    assert out.shape == (10, 10)
+    assert len(timer.times) == 1
+
+
+def test_report(tmp_path):
+    timer = Timer(save_jaxpr=False)
+    x = jnp.ones((10, 10))
+    y = jnp.ones((10, 10))
+
+    timer.chrono_jit(simple_add, x, y)
+    for _ in range(5):
+        timer.chrono_fun(simple_add, x, y)
+
+    csv_file = tmp_path / 'report.csv'
+    md_file = tmp_path / 'report.md'
+
+    # We specify nodes=1 explicitely to match what we expect in some tests, though default is 1
+    timer.report(
+        str(csv_file),
+        function='simple_add',
+        x=10,
+        y=10,
+        precision='float32',
+        md_filename=str(md_file),
+        extra_info={'custom_key': 'custom_val'},
+    )
+
+    assert csv_file.exists()
+    assert md_file.exists()
+
+    with open(csv_file) as f:
+        content = f.read()
+        assert 'simple_add' in content
+        assert 'float32' in content
+
+    with open(md_file) as f:
+        content = f.read()
+        assert '# Reporting for simple_add' in content
+        assert 'custom_key' in content
+        assert 'custom_val' in content
+
+
+def test_normalize_memory_units():
+    timer = Timer()
+    # Mocking internal state as if we had data
+    timer.jax_fn = True
+    timer.compile_info = True
+
+    assert timer._normalize_memory_units(100) == '100.00 B'
+    assert timer._normalize_memory_units(1024) == '1.00 KB'
+    assert timer._normalize_memory_units(1024**2) == '1.00 MB'
+    assert timer._normalize_memory_units(1024**3) == '1.00 GB'
+
+
+def test_get_mean_times():
+    timer = Timer()
+    timer.times = [10.0, 20.0, 30.0]
+
+    means = timer._get_mean_times()
+    assert isinstance(means, np.ndarray)
+    assert len(means) == 3
+    assert means[0] == 10.0
```
{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/LICENSE RENAMED
File without changes

{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/setup.cfg RENAMED
File without changes

{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/utils.py RENAMED
File without changes

{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/dependency_links.txt RENAMED
File without changes

{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/entry_points.txt RENAMED
File without changes

{jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/top_level.txt RENAMED
File without changes