jax-hpc-profiler 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/PKG-INFO +109 -54
- jax_hpc_profiler-0.2.1/README.md +201 -0
- jax_hpc_profiler-0.2.1/pyproject.toml +49 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler/create_argparse.py +3 -3
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler/main.py +3 -5
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler/timer.py +68 -41
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler/utils.py +28 -27
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/PKG-INFO +109 -54
- jax_hpc_profiler-0.2.0/README.md +0 -159
- jax_hpc_profiler-0.2.0/pyproject.toml +0 -23
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler/__init__.py +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler/plotting.py +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jax_hpc_profiler
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: HPC Plotter and profiler for benchmarking data made for JAX
|
|
5
5
|
Author: Wassim Kabalan
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -678,6 +678,19 @@ License: GNU GENERAL PUBLIC LICENSE
|
|
|
678
678
|
Public License instead of this License. But first, please read
|
|
679
679
|
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
|
680
680
|
|
|
681
|
+
Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
|
|
682
|
+
Keywords: jax,hpc,profiler,plotter,benchmarking
|
|
683
|
+
Classifier: Development Status :: 4 - Beta
|
|
684
|
+
Classifier: Intended Audience :: Developers
|
|
685
|
+
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
686
|
+
Classifier: Programming Language :: Python :: 3
|
|
687
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
688
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
689
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
690
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
691
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
692
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
693
|
+
Requires-Python: >=3.8
|
|
681
694
|
Description-Content-Type: text/markdown
|
|
682
695
|
License-File: LICENSE
|
|
683
696
|
Requires-Dist: numpy
|
|
@@ -686,9 +699,11 @@ Requires-Dist: matplotlib
|
|
|
686
699
|
Requires-Dist: seaborn
|
|
687
700
|
Requires-Dist: tabulate
|
|
688
701
|
|
|
689
|
-
|
|
702
|
+
Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
|
|
690
703
|
|
|
691
|
-
|
|
704
|
+
# JAX HPC Profiler
|
|
705
|
+
|
|
706
|
+
JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
|
|
692
707
|
|
|
693
708
|
## Table of Contents
|
|
694
709
|
- [Introduction](#introduction)
|
|
@@ -697,9 +712,10 @@ HPC Plotter is a tool designed for benchmarking and visualizing performance data
|
|
|
697
712
|
- [CSV Structure](#csv-structure)
|
|
698
713
|
- [Concatenating Files from Different Runs](#concatenating-files-from-different-runs)
|
|
699
714
|
- [Plotting CSV Data](#plotting-csv-data)
|
|
715
|
+
- [Examples](#examples)
|
|
700
716
|
|
|
701
717
|
## Introduction
|
|
702
|
-
HPC
|
|
718
|
+
JAX HPC Profiler allows users to:
|
|
703
719
|
1. Generate CSV files containing performance data.
|
|
704
720
|
2. Concatenate multiple CSV files from different runs.
|
|
705
721
|
3. Plot the performance data for analysis.
|
|
@@ -709,53 +725,80 @@ HPC Plotter allows users to:
|
|
|
709
725
|
To install the package, run the following command:
|
|
710
726
|
|
|
711
727
|
```bash
|
|
712
|
-
pip install hpc-
|
|
728
|
+
pip install jax-hpc-profiler
|
|
713
729
|
```
|
|
730
|
+
|
|
714
731
|
## Generating CSV Files Using the Timer Class
|
|
715
732
|
|
|
716
|
-
To generate CSV files, you can use the `Timer` class provided in the `
|
|
733
|
+
To generate CSV files, you can use the `Timer` class provided in the `jax_hpc_profiler.timer` module. This class helps in timing functions and saving the timing results to CSV files.
|
|
717
734
|
|
|
718
735
|
### Example Usage
|
|
719
736
|
|
|
720
737
|
```python
|
|
721
|
-
import time
|
|
722
|
-
from hpc_plotter.timer import Timer
|
|
723
738
|
import jax
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
timer = Timer()
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
739
|
+
from jax_hpc_profiler import Timer
|
|
740
|
+
|
|
741
|
+
def fcn(m, n, k):
|
|
742
|
+
return jax.numpy.dot(m, n) + k
|
|
743
|
+
|
|
744
|
+
timer = Timer(save_jaxpr=True)
|
|
745
|
+
m = jax.numpy.ones((1000, 1000))
|
|
746
|
+
n = jax.numpy.ones((1000, 1000))
|
|
747
|
+
k = jax.numpy.ones((1000, 1000))
|
|
748
|
+
|
|
749
|
+
timer.chrono_jit(fcn, m, n, k)
|
|
750
|
+
for i in range(10):
|
|
751
|
+
timer.chrono_fun(fcn, m, n, k)
|
|
752
|
+
|
|
753
|
+
meta_data = {
|
|
754
|
+
"function": "fcn",
|
|
755
|
+
"precision": "float32",
|
|
756
|
+
"x": 1000,
|
|
757
|
+
"y": 1000,
|
|
758
|
+
"z": 1000,
|
|
759
|
+
"px": 1,
|
|
760
|
+
"py": 1,
|
|
761
|
+
"backend": "NCCL",
|
|
762
|
+
"nodes": 1
|
|
763
|
+
}
|
|
764
|
+
extra_info = {
|
|
765
|
+
"done": "yes"
|
|
748
766
|
}
|
|
749
767
|
|
|
750
|
-
|
|
751
|
-
timer.print_to_csv('output.csv', **metadata)
|
|
768
|
+
timer.report("examples/profiling/test.csv", **meta_data, extra_info=extra_info)
|
|
752
769
|
```
|
|
753
770
|
|
|
771
|
+
`timer.report` has sensible defaults. Its parameters are:
|
|
772
|
+
|
|
773
|
+
- `csv_filename`: The path to the CSV file to save the timing data **(required)**.
|
|
774
|
+
- `function`: The name of the function being timed **(required)**.
|
|
775
|
+
- `x`: The size of the input data in the x dimension **(required)**.
|
|
776
|
+
- `y`: The size of the input data in the y dimension (by default same as x).
|
|
777
|
+
- `z`: The size of the input data in the z dimension (by default same as x).
|
|
778
|
+
- `precision`: The precision of the data (default: "float32").
|
|
779
|
+
- `px`: The number of partitions in the x dimension (default: 1).
|
|
780
|
+
- `py`: The number of partitions in the y dimension (default: 1).
|
|
781
|
+
- `backend`: The backend used for computation (default: "NCCL").
|
|
782
|
+
- `nodes`: The number of nodes used for computation (default: 1).
|
|
783
|
+
- `md_filename`: The path to the markdown file containing the compiled code and other information (default: {csv_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md).
|
|
784
|
+
- `extra_info`: Additional information to include in the report (default: `{}`).
|
|
785
|
+
|
|
786
|
+
`px` and `py` are used to specify the data decomposition. For example, if you have a 2D array of size 1000x1000 and you partition it into 4 parts (2x2), you would set `px=2` and `py=2`.\
|
|
787
|
+
They can also be used in a single-device run to specify the batch size.
|
|
788
|
+
|
|
789
|
+
Some decomposition parameters are generated that are specific to 3D data decomposition:\
|
|
790
|
+
`slab_yz` if the distributed axis is the y-axis.\
|
|
791
|
+
`slab_xy` if the distributed axis is the x-axis.\
|
|
792
|
+
`pencils` if the distributed axes are the x and y axes.
|
|
793
|
+
|
|
794
|
+
### Multi-GPU Setup
|
|
795
|
+
|
|
796
|
+
In a multi-GPU setup, the times are automatically averaged across ranks, providing a single performance metric for the entire setup.
|
|
797
|
+
|
|
754
798
|
## CSV Structure
|
|
755
799
|
|
|
756
800
|
The CSV files should follow a specific structure to ensure proper processing and concatenation. The directory structure should be organized by GPU type, with subdirectories for the number of GPUs and the respective CSV files.
|
|
757
801
|
|
|
758
|
-
|
|
759
802
|
### Example Directory Structure
|
|
760
803
|
|
|
761
804
|
```
|
|
@@ -790,12 +833,12 @@ root_directory/
|
|
|
790
833
|
|
|
791
834
|
## Concatenating Files from Different Runs
|
|
792
835
|
|
|
793
|
-
The `plot` function expects the directory to be organized as described above, but with the different number of GPUs
|
|
836
|
+
The `plot` function expects the directory to be organized as described above, but with the CSV files for the different numbers of GPUs together in the same directory. The `concatenate` function can be used to concatenate the CSV files from different runs into a single file.
|
|
794
837
|
|
|
795
838
|
### Example Usage
|
|
796
839
|
|
|
797
840
|
```bash
|
|
798
|
-
hpc-
|
|
841
|
+
jax-hpc-profiler concat /path/to/root_directory /path/to/output
|
|
799
842
|
```
|
|
800
843
|
|
|
801
844
|
And the output will be:
|
|
@@ -812,8 +855,6 @@ out_directory/
|
|
|
812
855
|
└── method_3.csv
|
|
813
856
|
```
|
|
814
857
|
|
|
815
|
-
|
|
816
|
-
|
|
817
858
|
## Plotting CSV Data
|
|
818
859
|
|
|
819
860
|
You can plot the performance data using the `plot` command. The plotting command provides various options to customize the plots.
|
|
@@ -821,27 +862,41 @@ You can plot the performance data using the `plot` command. The plotting command
|
|
|
821
862
|
### Usage
|
|
822
863
|
|
|
823
864
|
```bash
|
|
824
|
-
hpc-
|
|
865
|
+
jax-hpc-profiler plot -f <csv_files> [options]
|
|
825
866
|
```
|
|
826
867
|
|
|
827
|
-
|
|
828
|
-
|
|
868
|
+
### Options
|
|
829
869
|
|
|
830
870
|
- `-f, --csv_files`: List of CSV files to plot (required).
|
|
831
|
-
- `-g, --gpus`:
|
|
832
|
-
- `-d, --data_size`:
|
|
871
|
+
- `-g, --gpus`: List of number of GPUs to plot.
|
|
872
|
+
- `-d, --data_size`: List of data sizes to plot.
|
|
833
873
|
- `-fd, --filter_pdims`: List of pdims to filter (e.g., 1x4 2x2 4x8).
|
|
834
|
-
- `-ps, --
|
|
835
|
-
- `plot_all`: Plot every decomposition.
|
|
874
|
+
- `-ps, --pdim_strategy`: Strategy for plotting pdims. This argument can be multiple ones (`plot_all`, `plot_fastest`, `slab_yz`, `slab_xy`, `pencils`).
|
|
875
|
+
- `plot_all`: Plot every decomposition.
|
|
836
876
|
- `plot_fastest`: Plot the fastest decomposition.
|
|
837
|
-
- `-
|
|
838
|
-
- `-fn, --function_name`: Function
|
|
839
|
-
- `-
|
|
840
|
-
- `-
|
|
877
|
+
- `-pr, --precision`: Precision to filter by. This argument can be multiple ones (`float32`, `float64`).
|
|
878
|
+
- `-fn, --function_name`: Function names to filter. This argument can be multiple ones.
|
|
879
|
+
- `-pt, --plot_times`: Time columns to plot (`jit_time`, `min_time`, `max_time`, `mean_time`, `std_time`, `last_time`). Note: You cannot plot memory and time together.
|
|
880
|
+
- `-pm, --plot_memory`: Memory columns to plot (`generated_code`, `argument_size`, `output_size`, `temp_size`). Note: You cannot plot memory and time together.
|
|
881
|
+
- `-mu, --memory_units`: Memory units to plot (`KB`, `MB`, `GB`, `TB`).
|
|
841
882
|
- `-fs, --figure_size`: Figure size.
|
|
842
|
-
- `-nl, --nodes_in_label`: Use node names in labels.
|
|
843
883
|
- `-o, --output`: Output file (if none then only show plot).
|
|
844
884
|
- `-db, --dark_bg`: Use dark background for plotting.
|
|
845
|
-
- `-pd, --print_decompositions`: Print decompositions on plot (
|
|
846
|
-
- `-b, --backends`: List of backends to include.
|
|
847
|
-
- `-sc, --scaling`: Scaling type (`Weak
|
|
885
|
+
- `-pd, --print_decompositions`: Print decompositions on plot (experimental).
|
|
886
|
+
- `-b, --backends`: List of backends to include. This argument can be multiple ones.
|
|
887
|
+
- `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
|
|
888
|
+
- `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
|
|
889
|
+
|
|
890
|
+
## Examples
|
|
891
|
+
|
|
892
|
+
The repository includes examples for both profiling and plotting.
|
|
893
|
+
|
|
894
|
+
### Profiling Example
|
|
895
|
+
|
|
896
|
+
See the `examples/profiling` directory for profiling examples, including `function.py`, `test.csv`, and the generated markdown report.
|
|
897
|
+
|
|
898
|
+
### Plotting Example
|
|
899
|
+
|
|
900
|
+
See the `examples/plotting` directory for plotting examples, including `generator.py`, `sample_data1.csv`, `sample_data2.csv`, and `sample_data3.csv`.
|
|
901
|
+
|
|
902
|
+
A multi-GPU example comparing distributed FFTs can be found here: [jaxdecomp-benchmarks](https://github.com/ASKabalan/jaxdecomp-benchmarks).
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
|
|
2
|
+
|
|
3
|
+
# JAX HPC Profiler
|
|
4
|
+
|
|
5
|
+
JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
|
|
6
|
+
|
|
7
|
+
## Table of Contents
|
|
8
|
+
- [Introduction](#introduction)
|
|
9
|
+
- [Installation](#installation)
|
|
10
|
+
- [Generating CSV Files Using the Timer Class](#generating-csv-files-using-the-timer-class)
|
|
11
|
+
- [CSV Structure](#csv-structure)
|
|
12
|
+
- [Concatenating Files from Different Runs](#concatenating-files-from-different-runs)
|
|
13
|
+
- [Plotting CSV Data](#plotting-csv-data)
|
|
14
|
+
- [Examples](#examples)
|
|
15
|
+
|
|
16
|
+
## Introduction
|
|
17
|
+
JAX HPC Profiler allows users to:
|
|
18
|
+
1. Generate CSV files containing performance data.
|
|
19
|
+
2. Concatenate multiple CSV files from different runs.
|
|
20
|
+
3. Plot the performance data for analysis.
|
|
21
|
+
|
|
22
|
+
## Installation
|
|
23
|
+
|
|
24
|
+
To install the package, run the following command:
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
pip install jax-hpc-profiler
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
## Generating CSV Files Using the Timer Class
|
|
31
|
+
|
|
32
|
+
To generate CSV files, you can use the `Timer` class provided in the `jax_hpc_profiler.timer` module. This class helps in timing functions and saving the timing results to CSV files.
|
|
33
|
+
|
|
34
|
+
### Example Usage
|
|
35
|
+
|
|
36
|
+
```python
|
|
37
|
+
import jax
|
|
38
|
+
from jax_hpc_profiler import Timer
|
|
39
|
+
|
|
40
|
+
def fcn(m, n, k):
|
|
41
|
+
return jax.numpy.dot(m, n) + k
|
|
42
|
+
|
|
43
|
+
timer = Timer(save_jaxpr=True)
|
|
44
|
+
m = jax.numpy.ones((1000, 1000))
|
|
45
|
+
n = jax.numpy.ones((1000, 1000))
|
|
46
|
+
k = jax.numpy.ones((1000, 1000))
|
|
47
|
+
|
|
48
|
+
timer.chrono_jit(fcn, m, n, k)
|
|
49
|
+
for i in range(10):
|
|
50
|
+
timer.chrono_fun(fcn, m, n, k)
|
|
51
|
+
|
|
52
|
+
meta_data = {
|
|
53
|
+
"function": "fcn",
|
|
54
|
+
"precision": "float32",
|
|
55
|
+
"x": 1000,
|
|
56
|
+
"y": 1000,
|
|
57
|
+
"z": 1000,
|
|
58
|
+
"px": 1,
|
|
59
|
+
"py": 1,
|
|
60
|
+
"backend": "NCCL",
|
|
61
|
+
"nodes": 1
|
|
62
|
+
}
|
|
63
|
+
extra_info = {
|
|
64
|
+
"done": "yes"
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
timer.report("examples/profiling/test.csv", **meta_data, extra_info=extra_info)
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
`timer.report` has sensible defaults. Its parameters are:
|
|
71
|
+
|
|
72
|
+
- `csv_filename`: The path to the CSV file to save the timing data **(required)**.
|
|
73
|
+
- `function`: The name of the function being timed **(required)**.
|
|
74
|
+
- `x`: The size of the input data in the x dimension **(required)**.
|
|
75
|
+
- `y`: The size of the input data in the y dimension (by default same as x).
|
|
76
|
+
- `z`: The size of the input data in the z dimension (by default same as x).
|
|
77
|
+
- `precision`: The precision of the data (default: "float32").
|
|
78
|
+
- `px`: The number of partitions in the x dimension (default: 1).
|
|
79
|
+
- `py`: The number of partitions in the y dimension (default: 1).
|
|
80
|
+
- `backend`: The backend used for computation (default: "NCCL").
|
|
81
|
+
- `nodes`: The number of nodes used for computation (default: 1).
|
|
82
|
+
- `md_filename`: The path to the markdown file containing the compiled code and other information (default: {csv_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md).
|
|
83
|
+
- `extra_info`: Additional information to include in the report (default: `{}`).
|
|
84
|
+
|
|
85
|
+
`px` and `py` are used to specify the data decomposition. For example, if you have a 2D array of size 1000x1000 and you partition it into 4 parts (2x2), you would set `px=2` and `py=2`.\
|
|
86
|
+
They can also be used in a single-device run to specify the batch size.
|
|
87
|
+
|
|
88
|
+
Some decomposition parameters are generated that are specific to 3D data decomposition:\
|
|
89
|
+
`slab_yz` if the distributed axis is the y-axis.\
|
|
90
|
+
`slab_xy` if the distributed axis is the x-axis.\
|
|
91
|
+
`pencils` if the distributed axes are the x and y axes.
|
|
92
|
+
|
|
93
|
+
### Multi-GPU Setup
|
|
94
|
+
|
|
95
|
+
In a multi-GPU setup, the times are automatically averaged across ranks, providing a single performance metric for the entire setup.
|
|
96
|
+
|
|
97
|
+
## CSV Structure
|
|
98
|
+
|
|
99
|
+
The CSV files should follow a specific structure to ensure proper processing and concatenation. The directory structure should be organized by GPU type, with subdirectories for the number of GPUs and the respective CSV files.
|
|
100
|
+
|
|
101
|
+
### Example Directory Structure
|
|
102
|
+
|
|
103
|
+
```
|
|
104
|
+
root_directory/
|
|
105
|
+
├── gpu_1/
|
|
106
|
+
│ ├── 2/
|
|
107
|
+
│ │ ├── method_1.csv
|
|
108
|
+
│ │ ├── method_2.csv
|
|
109
|
+
│ │ └── method_3.csv
|
|
110
|
+
│ ├── 4/
|
|
111
|
+
│ │ ├── method_1.csv
|
|
112
|
+
│ │ ├── method_2.csv
|
|
113
|
+
│ │ └── method_3.csv
|
|
114
|
+
│ └── 8/
|
|
115
|
+
│ ├── method_1.csv
|
|
116
|
+
│ ├── method_2.csv
|
|
117
|
+
│ └── method_3.csv
|
|
118
|
+
└── gpu_2/
|
|
119
|
+
├── 2/
|
|
120
|
+
│ ├── method_1.csv
|
|
121
|
+
│ ├── method_2.csv
|
|
122
|
+
│ └── method_3.csv
|
|
123
|
+
├── 4/
|
|
124
|
+
│ ├── method_1.csv
|
|
125
|
+
│ ├── method_2.csv
|
|
126
|
+
│ └── method_3.csv
|
|
127
|
+
└── 8/
|
|
128
|
+
├── method_1.csv
|
|
129
|
+
├── method_2.csv
|
|
130
|
+
└── method_3.csv
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
## Concatenating Files from Different Runs
|
|
134
|
+
|
|
135
|
+
The `plot` function expects the directory to be organized as described above, but with the CSV files for the different numbers of GPUs together in the same directory. The `concatenate` function can be used to concatenate the CSV files from different runs into a single file.
|
|
136
|
+
|
|
137
|
+
### Example Usage
|
|
138
|
+
|
|
139
|
+
```bash
|
|
140
|
+
jax-hpc-profiler concat /path/to/root_directory /path/to/output
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
And the output will be:
|
|
144
|
+
|
|
145
|
+
```
|
|
146
|
+
out_directory/
|
|
147
|
+
├── gpu_1/
|
|
148
|
+
│ ├── method_1.csv
|
|
149
|
+
│ ├── method_2.csv
|
|
150
|
+
│ └── method_3.csv
|
|
151
|
+
└── gpu_2/
|
|
152
|
+
├── method_1.csv
|
|
153
|
+
├── method_2.csv
|
|
154
|
+
└── method_3.csv
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
## Plotting CSV Data
|
|
158
|
+
|
|
159
|
+
You can plot the performance data using the `plot` command. The plotting command provides various options to customize the plots.
|
|
160
|
+
|
|
161
|
+
### Usage
|
|
162
|
+
|
|
163
|
+
```bash
|
|
164
|
+
jax-hpc-profiler plot -f <csv_files> [options]
|
|
165
|
+
```
|
|
166
|
+
|
|
167
|
+
### Options
|
|
168
|
+
|
|
169
|
+
- `-f, --csv_files`: List of CSV files to plot (required).
|
|
170
|
+
- `-g, --gpus`: List of number of GPUs to plot.
|
|
171
|
+
- `-d, --data_size`: List of data sizes to plot.
|
|
172
|
+
- `-fd, --filter_pdims`: List of pdims to filter (e.g., 1x4 2x2 4x8).
|
|
173
|
+
- `-ps, --pdim_strategy`: Strategy for plotting pdims. This argument can be multiple ones (`plot_all`, `plot_fastest`, `slab_yz`, `slab_xy`, `pencils`).
|
|
174
|
+
- `plot_all`: Plot every decomposition.
|
|
175
|
+
- `plot_fastest`: Plot the fastest decomposition.
|
|
176
|
+
- `-pr, --precision`: Precision to filter by. This argument can be multiple ones (`float32`, `float64`).
|
|
177
|
+
- `-fn, --function_name`: Function names to filter. This argument can be multiple ones.
|
|
178
|
+
- `-pt, --plot_times`: Time columns to plot (`jit_time`, `min_time`, `max_time`, `mean_time`, `std_time`, `last_time`). Note: You cannot plot memory and time together.
|
|
179
|
+
- `-pm, --plot_memory`: Memory columns to plot (`generated_code`, `argument_size`, `output_size`, `temp_size`). Note: You cannot plot memory and time together.
|
|
180
|
+
- `-mu, --memory_units`: Memory units to plot (`KB`, `MB`, `GB`, `TB`).
|
|
181
|
+
- `-fs, --figure_size`: Figure size.
|
|
182
|
+
- `-o, --output`: Output file (if none then only show plot).
|
|
183
|
+
- `-db, --dark_bg`: Use dark background for plotting.
|
|
184
|
+
- `-pd, --print_decompositions`: Print decompositions on plot (experimental).
|
|
185
|
+
- `-b, --backends`: List of backends to include. This argument can be multiple ones.
|
|
186
|
+
- `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
|
|
187
|
+
- `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
|
|
188
|
+
|
|
189
|
+
## Examples
|
|
190
|
+
|
|
191
|
+
The repository includes examples for both profiling and plotting.
|
|
192
|
+
|
|
193
|
+
### Profiling Example
|
|
194
|
+
|
|
195
|
+
See the `examples/profiling` directory for profiling examples, including `function.py`, `test.csv`, and the generated markdown report.
|
|
196
|
+
|
|
197
|
+
### Plotting Example
|
|
198
|
+
|
|
199
|
+
See the `examples/plotting` directory for plotting examples, including `generator.py`, `sample_data1.csv`, `sample_data2.csv`, and `sample_data3.csv`.
|
|
200
|
+
|
|
201
|
+
A multi-GPU example comparing distributed FFTs can be found here: [jaxdecomp-benchmarks](https://github.com/ASKabalan/jaxdecomp-benchmarks).
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "jax_hpc_profiler"
|
|
7
|
+
version = "0.2.1"
|
|
8
|
+
description = "HPC Plotter and profiler for benchmarking data made for JAX"
|
|
9
|
+
authors = [
|
|
10
|
+
{ name="Wassim Kabalan" }
|
|
11
|
+
]
|
|
12
|
+
dependencies = [
|
|
13
|
+
"numpy",
|
|
14
|
+
"pandas",
|
|
15
|
+
"matplotlib",
|
|
16
|
+
"seaborn",
|
|
17
|
+
"tabulate"
|
|
18
|
+
]
|
|
19
|
+
readme = "README.md"
|
|
20
|
+
license = { file = "LICENSE" }
|
|
21
|
+
requires-python = ">=3.8"
|
|
22
|
+
keywords = ["jax", "hpc", "profiler", "plotter", "benchmarking"]
|
|
23
|
+
|
|
24
|
+
# For a list of valid classifiers, see https://pypi.org/classifiers/
|
|
25
|
+
classifiers = [
|
|
26
|
+
"Development Status :: 4 - Beta",
|
|
27
|
+
|
|
28
|
+
# Indicate who your project is intended for
|
|
29
|
+
"Intended Audience :: Developers",
|
|
30
|
+
|
|
31
|
+
# Pick your license as you wish
|
|
32
|
+
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
|
|
33
|
+
|
|
34
|
+
# Specify the Python versions you support here. In particular, ensure
|
|
35
|
+
# that you indicate you support Python 3. These classifiers are *not*
|
|
36
|
+
# checked by "pip install". See instead "requires-python" key in this file.
|
|
37
|
+
"Programming Language :: Python :: 3",
|
|
38
|
+
"Programming Language :: Python :: 3.8",
|
|
39
|
+
"Programming Language :: Python :: 3.9",
|
|
40
|
+
"Programming Language :: Python :: 3.10",
|
|
41
|
+
"Programming Language :: Python :: 3.11",
|
|
42
|
+
"Programming Language :: Python :: 3.12",
|
|
43
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
urls = { "Homepage" = "https://github.com/ASKabalan/jax-hpc-profiler" }
|
|
47
|
+
|
|
48
|
+
[project.scripts]
|
|
49
|
+
jhp = "jax_hpc_profiler.main:main"
|
|
@@ -89,9 +89,9 @@ def create_argparser():
|
|
|
89
89
|
],
|
|
90
90
|
help='Memory columns to plot')
|
|
91
91
|
plot_parser.add_argument('-mu',
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
92
|
+
'--memory_units',
|
|
93
|
+
default='GB',
|
|
94
|
+
help='Memory units to plot (KB, MB, GB, TB)')
|
|
95
95
|
|
|
96
96
|
# Plot customization arguments
|
|
97
97
|
plot_parser.add_argument('-fs',
|
|
@@ -15,7 +15,7 @@ def main():
|
|
|
15
15
|
dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
|
|
16
16
|
args.csv_files, args.precision, args.function_name, args.gpus,
|
|
17
17
|
args.data_size, args.filter_pdims, args.pdim_strategy,
|
|
18
|
-
args.backends,args.memory_units)
|
|
18
|
+
args.backends, args.memory_units)
|
|
19
19
|
if len(dataframes) == 0:
|
|
20
20
|
print(f"No dataframes found for the given arguments. Exiting...")
|
|
21
21
|
sys.exit(1)
|
|
@@ -29,12 +29,10 @@ def main():
|
|
|
29
29
|
if data_size in available_data_sizes
|
|
30
30
|
]
|
|
31
31
|
if len(args.gpus) == 0:
|
|
32
|
-
print(
|
|
33
|
-
f"No dataframes found for the given GPUs. Exiting...")
|
|
32
|
+
print(f"No dataframes found for the given GPUs. Exiting...")
|
|
34
33
|
sys.exit(1)
|
|
35
34
|
if len(args.data_size) == 0:
|
|
36
|
-
print(
|
|
37
|
-
f"No dataframes found for the given data sizes. Exiting...")
|
|
35
|
+
print(f"No dataframes found for the given data sizes. Exiting...")
|
|
38
36
|
sys.exit(1)
|
|
39
37
|
|
|
40
38
|
if args.scaling == 'Weak':
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import time
|
|
3
3
|
from functools import partial
|
|
4
|
-
from typing import Any, Callable, List
|
|
4
|
+
from typing import Any, Callable, List, Tuple
|
|
5
5
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import numpy as np
|
|
9
9
|
from jax import make_jaxpr
|
|
10
|
+
from jax.experimental import mesh_utils
|
|
10
11
|
from jax.experimental.shard_map import shard_map
|
|
11
12
|
from jax.sharding import Mesh, NamedSharding
|
|
12
13
|
from jax.sharding import PartitionSpec as P
|
|
@@ -22,6 +23,19 @@ class Timer:
|
|
|
22
23
|
self.compiled_code = {}
|
|
23
24
|
self.save_jaxpr = save_jaxpr
|
|
24
25
|
|
|
26
|
+
def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
|
|
27
|
+
if cost_analysis is None:
|
|
28
|
+
return None
|
|
29
|
+
return cost_analysis[0]['flops']
|
|
30
|
+
|
|
31
|
+
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
32
|
+
if memory_analysis is None:
|
|
33
|
+
return None, None, None, None
|
|
34
|
+
return (memory_analysis.generated_code_size_in_bytes,
|
|
35
|
+
memory_analysis.argument_size_in_bytes,
|
|
36
|
+
memory_analysis.output_size_in_bytes,
|
|
37
|
+
memory_analysis.temp_size_in_bytes)
|
|
38
|
+
|
|
25
39
|
def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
26
40
|
start = time.perf_counter()
|
|
27
41
|
out = jax.jit(fun)(*args)
|
|
@@ -38,18 +52,17 @@ class Timer:
|
|
|
38
52
|
|
|
39
53
|
lowered = jax.jit(fun).lower(*args)
|
|
40
54
|
compiled = lowered.compile()
|
|
41
|
-
memory_analysis =
|
|
55
|
+
memory_analysis = self._read_memory_analysis(
|
|
56
|
+
compiled.memory_analysis())
|
|
57
|
+
cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
|
|
58
|
+
|
|
42
59
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
43
60
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
44
|
-
self.profiling_data["FLOPS"] =
|
|
45
|
-
self.profiling_data[
|
|
46
|
-
|
|
47
|
-
self.profiling_data[
|
|
48
|
-
|
|
49
|
-
self.profiling_data[
|
|
50
|
-
"output_size"] = memory_analysis.output_size_in_bytes
|
|
51
|
-
self.profiling_data["temp_size"] = memory_analysis.temp_size_in_bytes
|
|
52
|
-
|
|
61
|
+
self.profiling_data["FLOPS"] = cost_analysis
|
|
62
|
+
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
63
|
+
self.profiling_data["argument_size"] = memory_analysis[0]
|
|
64
|
+
self.profiling_data["output_size"] = memory_analysis[0]
|
|
65
|
+
self.profiling_data["temp_size"] = memory_analysis[0]
|
|
53
66
|
return out
|
|
54
67
|
|
|
55
68
|
def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
@@ -63,56 +76,62 @@ class Timer:
|
|
|
63
76
|
self.times.append((end - start) * 1e3)
|
|
64
77
|
return out
|
|
65
78
|
|
|
66
|
-
def _get_mean_times(self
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
79
|
+
def _get_mean_times(self) -> np.ndarray:
|
|
80
|
+
if jax.device_count() == 1:
|
|
81
|
+
return np.array(self.times)
|
|
82
|
+
|
|
83
|
+
devices = mesh_utils.create_device_mesh((jax.device_count(), ))
|
|
84
|
+
mesh = Mesh(devices, ('x', ))
|
|
85
|
+
sharding = NamedSharding(mesh, P('x'))
|
|
86
|
+
|
|
87
|
+
times_array = jnp.array(self.times)
|
|
88
|
+
global_shape = (jax.device_count(), times_array.shape[0])
|
|
89
|
+
global_times = jax.make_array_from_callback(
|
|
90
|
+
shape=global_shape,
|
|
91
|
+
sharding=sharding,
|
|
92
|
+
data_callback=lambda x: times_array)
|
|
73
93
|
|
|
74
94
|
@partial(shard_map,
|
|
75
95
|
mesh=mesh,
|
|
76
|
-
in_specs=
|
|
96
|
+
in_specs=P('x'),
|
|
77
97
|
out_specs=P(),
|
|
78
98
|
check_rep=False)
|
|
79
99
|
def get_mean_times(times):
|
|
80
|
-
|
|
81
|
-
for axis_name in valid_letters[1:]:
|
|
82
|
-
mean = jax.lax.pmean(mean, axis_name=axis_name)
|
|
83
|
-
return mean
|
|
100
|
+
return jax.lax.pmean(times, axis_name='x')
|
|
84
101
|
|
|
85
|
-
times_array = get_mean_times(
|
|
102
|
+
times_array = get_mean_times(global_times)
|
|
86
103
|
times_array.block_until_ready()
|
|
87
|
-
return times_array
|
|
104
|
+
return np.array(times_array.addressable_data(0))
|
|
88
105
|
|
|
89
106
|
def report(self,
|
|
90
107
|
csv_filename: str,
|
|
91
108
|
function: str,
|
|
92
|
-
precision: str,
|
|
93
109
|
x: int,
|
|
94
|
-
y: int,
|
|
95
|
-
z: int,
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
110
|
+
y: int | None = None,
|
|
111
|
+
z: int | None = None,
|
|
112
|
+
precision: str = "float32",
|
|
113
|
+
px: int = 1,
|
|
114
|
+
py: int = 1,
|
|
115
|
+
backend: str = "NCCL",
|
|
116
|
+
nodes: int = 1,
|
|
101
117
|
md_filename: str | None = None,
|
|
102
118
|
extra_info: dict = {}):
|
|
103
|
-
times_array = jnp.array(self.times)
|
|
104
119
|
|
|
105
120
|
if md_filename is None:
|
|
106
|
-
dirname, filename = os.path.dirname(
|
|
121
|
+
dirname, filename = os.path.dirname(
|
|
122
|
+
csv_filename), os.path.splitext(
|
|
123
|
+
os.path.basename(csv_filename))[0]
|
|
107
124
|
report_folder = filename if dirname == "" else f"{dirname}/{filename}"
|
|
108
|
-
print(
|
|
125
|
+
print(
|
|
126
|
+
f"report_folder: {report_folder} csv_filename: {csv_filename}")
|
|
109
127
|
os.makedirs(report_folder, exist_ok=True)
|
|
110
128
|
md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
|
|
111
129
|
|
|
112
|
-
if
|
|
113
|
-
|
|
130
|
+
y = x if y is None else y
|
|
131
|
+
z = x if z is None else z
|
|
132
|
+
|
|
133
|
+
times_array = self._get_mean_times()
|
|
114
134
|
|
|
115
|
-
times_array = np.array(times_array)
|
|
116
135
|
min_time = np.min(times_array)
|
|
117
136
|
max_time = np.max(times_array)
|
|
118
137
|
mean_time = np.mean(times_array)
|
|
@@ -163,10 +182,18 @@ class Timer:
|
|
|
163
182
|
with open(md_filename, 'w') as f:
|
|
164
183
|
f.write(f"# Reporting for {function}\n")
|
|
165
184
|
f.write(f"## Parameters\n")
|
|
166
|
-
|
|
185
|
+
keys = list(param_dict.keys())
|
|
186
|
+
values = list(param_dict.values())
|
|
187
|
+
f.write(
|
|
188
|
+
tabulate(param_dict.items(),
|
|
189
|
+
headers=["Parameter", "Value"],
|
|
190
|
+
tablefmt='github'))
|
|
167
191
|
f.write("\n---\n")
|
|
168
192
|
f.write(f"## Profiling Data\n")
|
|
169
|
-
f.write(
|
|
193
|
+
f.write(
|
|
194
|
+
tabulate(profiling_result.items(),
|
|
195
|
+
headers=["Parameter", "Value"],
|
|
196
|
+
tablefmt='github'))
|
|
170
197
|
f.write("\n---\n")
|
|
171
198
|
f.write(f"## Compiled Code\n")
|
|
172
199
|
f.write(f"```hlo\n")
|
|
@@ -6,6 +6,7 @@ import numpy as np
|
|
|
6
6
|
import pandas as pd
|
|
7
7
|
from matplotlib.axes import Axes
|
|
8
8
|
|
|
9
|
+
|
|
9
10
|
def inspect_data(dataframes: Dict[str, pd.DataFrame]):
|
|
10
11
|
"""
|
|
11
12
|
Inspect the dataframes.
|
|
@@ -203,17 +204,17 @@ def concatenate_csvs(root_dir: str, output_dir: str):
|
|
|
203
204
|
if file.endswith('.csv'):
|
|
204
205
|
csv_file_path = os.path.join(root, file)
|
|
205
206
|
print(f'Concatenating {csv_file_path}...')
|
|
206
|
-
df = pd.read_csv(
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
207
|
+
df = pd.read_csv(csv_file_path,
|
|
208
|
+
header=None,
|
|
209
|
+
names=[
|
|
210
|
+
"function", "precision", "x", "y",
|
|
211
|
+
"z", "px", "py", "backend", "nodes",
|
|
212
|
+
"jit_time", "min_time", "max_time",
|
|
213
|
+
"mean_time", "std_div", "last_time",
|
|
214
|
+
"generated_code", "argument_size",
|
|
215
|
+
"output_size", "temp_size", "flops"
|
|
216
|
+
],
|
|
217
|
+
index_col=False)
|
|
217
218
|
if file not in combined_dfs:
|
|
218
219
|
combined_dfs[file] = df
|
|
219
220
|
else:
|
|
@@ -340,23 +341,23 @@ def clean_up_csv(
|
|
|
340
341
|
if pdims:
|
|
341
342
|
px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
|
|
342
343
|
df = df[(df['px'].isin(px_list)) & (df['py'].isin(py_list))]
|
|
343
|
-
|
|
344
|
-
# convert memory units columns to remquested memory_units
|
|
344
|
+
|
|
345
|
+
# convert memory units columns to requested memory_units
|
|
345
346
|
match memory_units:
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
347
|
+
case 'KB':
|
|
348
|
+
factor = 1024
|
|
349
|
+
case 'MB':
|
|
350
|
+
factor = 1024**2
|
|
351
|
+
case 'GB':
|
|
352
|
+
factor = 1024**3
|
|
353
|
+
case 'TB':
|
|
354
|
+
factor = 1024**4
|
|
355
|
+
case _:
|
|
356
|
+
factor = 1
|
|
357
|
+
|
|
357
358
|
df['generated_code'] = df['generated_code'] / factor
|
|
358
|
-
df['argument_size'] = df['argument_size'] / factor
|
|
359
|
-
df['output_size'] = df['output_size'] / factor
|
|
359
|
+
df['argument_size'] = df['argument_size'] / factor
|
|
360
|
+
df['output_size'] = df['output_size'] / factor
|
|
360
361
|
df['temp_size'] = df['temp_size'] / factor
|
|
361
362
|
# in case of the same test is run multiple times, keep the last one
|
|
362
363
|
df = df.drop_duplicates(subset=[
|
|
@@ -383,7 +384,7 @@ def clean_up_csv(
|
|
|
383
384
|
df['decomp'] = df.apply(get_decomp_from_px_py, axis=1)
|
|
384
385
|
df.drop(columns=['px', 'py'], inplace=True)
|
|
385
386
|
if not 'plot_all' in pdims_strategy:
|
|
386
|
-
|
|
387
|
+
df = df[df['decomp'].isin(pdims_strategy)]
|
|
387
388
|
# check available gpus in dataset
|
|
388
389
|
available_gpu_counts.update(df['gpus'].unique())
|
|
389
390
|
available_data_sizes.update(df['x'].unique())
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jax_hpc_profiler
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: HPC Plotter and profiler for benchmarking data made for JAX
|
|
5
5
|
Author: Wassim Kabalan
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -678,6 +678,19 @@ License: GNU GENERAL PUBLIC LICENSE
|
|
|
678
678
|
Public License instead of this License. But first, please read
|
|
679
679
|
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
|
680
680
|
|
|
681
|
+
Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
|
|
682
|
+
Keywords: jax,hpc,profiler,plotter,benchmarking
|
|
683
|
+
Classifier: Development Status :: 4 - Beta
|
|
684
|
+
Classifier: Intended Audience :: Developers
|
|
685
|
+
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
686
|
+
Classifier: Programming Language :: Python :: 3
|
|
687
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
688
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
689
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
690
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
691
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
692
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
693
|
+
Requires-Python: >=3.8
|
|
681
694
|
Description-Content-Type: text/markdown
|
|
682
695
|
License-File: LICENSE
|
|
683
696
|
Requires-Dist: numpy
|
|
@@ -686,9 +699,11 @@ Requires-Dist: matplotlib
|
|
|
686
699
|
Requires-Dist: seaborn
|
|
687
700
|
Requires-Dist: tabulate
|
|
688
701
|
|
|
689
|
-
|
|
702
|
+
Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
|
|
690
703
|
|
|
691
|
-
|
|
704
|
+
# JAX HPC Profiler
|
|
705
|
+
|
|
706
|
+
JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
|
|
692
707
|
|
|
693
708
|
## Table of Contents
|
|
694
709
|
- [Introduction](#introduction)
|
|
@@ -697,9 +712,10 @@ HPC Plotter is a tool designed for benchmarking and visualizing performance data
|
|
|
697
712
|
- [CSV Structure](#csv-structure)
|
|
698
713
|
- [Concatenating Files from Different Runs](#concatenating-files-from-different-runs)
|
|
699
714
|
- [Plotting CSV Data](#plotting-csv-data)
|
|
715
|
+
- [Examples](#examples)
|
|
700
716
|
|
|
701
717
|
## Introduction
|
|
702
|
-
HPC
|
|
718
|
+
JAX HPC Profiler allows users to:
|
|
703
719
|
1. Generate CSV files containing performance data.
|
|
704
720
|
2. Concatenate multiple CSV files from different runs.
|
|
705
721
|
3. Plot the performance data for analysis.
|
|
@@ -709,53 +725,80 @@ HPC Plotter allows users to:
|
|
|
709
725
|
To install the package, run the following command:
|
|
710
726
|
|
|
711
727
|
```bash
|
|
712
|
-
pip install hpc-
|
|
728
|
+
pip install jax-hpc-profiler
|
|
713
729
|
```
|
|
730
|
+
|
|
714
731
|
## Generating CSV Files Using the Timer Class
|
|
715
732
|
|
|
716
|
-
To generate CSV files, you can use the `Timer` class provided in the `
|
|
733
|
+
To generate CSV files, you can use the `Timer` class provided in the `jax_hpc_profiler.timer` module. This class helps in timing functions and saving the timing results to CSV files.
|
|
717
734
|
|
|
718
735
|
### Example Usage
|
|
719
736
|
|
|
720
737
|
```python
|
|
721
|
-
import time
|
|
722
|
-
from hpc_plotter.timer import Timer
|
|
723
738
|
import jax
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
timer = Timer()
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
739
|
+
from jax_hpc_profiler import Timer
|
|
740
|
+
|
|
741
|
+
def fcn(m, n, k):
|
|
742
|
+
return jax.numpy.dot(m, n) + k
|
|
743
|
+
|
|
744
|
+
timer = Timer(save_jaxpr=True)
|
|
745
|
+
m = jax.numpy.ones((1000, 1000))
|
|
746
|
+
n = jax.numpy.ones((1000, 1000))
|
|
747
|
+
k = jax.numpy.ones((1000, 1000))
|
|
748
|
+
|
|
749
|
+
timer.chrono_jit(fcn, m, n, k)
|
|
750
|
+
for i in range(10):
|
|
751
|
+
timer.chrono_fun(fcn, m, n, k)
|
|
752
|
+
|
|
753
|
+
meta_data = {
|
|
754
|
+
"function": "fcn",
|
|
755
|
+
"precision": "float32",
|
|
756
|
+
"x": 1000,
|
|
757
|
+
"y": 1000,
|
|
758
|
+
"z": 1000,
|
|
759
|
+
"px": 1,
|
|
760
|
+
"py": 1,
|
|
761
|
+
"backend": "NCCL",
|
|
762
|
+
"nodes": 1
|
|
763
|
+
}
|
|
764
|
+
extra_info = {
|
|
765
|
+
"done": "yes"
|
|
748
766
|
}
|
|
749
767
|
|
|
750
|
-
|
|
751
|
-
timer.print_to_csv('output.csv', **metadata)
|
|
768
|
+
timer.report("examples/profiling/test.csv", **meta_data, extra_info=extra_info)
|
|
752
769
|
```
|
|
753
770
|
|
|
771
|
+
`timer.report` has sensible defaults. Its parameters are listed below:
|
|
772
|
+
|
|
773
|
+
- `csv_filename`: The path to the CSV file to save the timing data **(required)**.
|
|
774
|
+
- `function`: The name of the function being timed **(required)**.
|
|
775
|
+
- `x`: The size of the input data in the x dimension **(required)**.
|
|
776
|
+
- `y`: The size of the input data in the y dimension (by default same as x).
|
|
777
|
+
- `z`: The size of the input data in the z dimension (by default same as x).
|
|
778
|
+
- `precision`: The precision of the data (default: "float32").
|
|
779
|
+
- `px`: The number of partitions in the x dimension (default: 1).
|
|
780
|
+
- `py`: The number of partitions in the y dimension (default: 1).
|
|
781
|
+
- `backend`: The backend used for computation (default: "NCCL").
|
|
782
|
+
- `nodes`: The number of nodes used for computation (default: 1).
|
|
783
|
+
- `md_filename`: The path to the markdown file containing the compiled code and other information (default: {csv_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md).
|
|
784
|
+
- `extra_info`: Additional information to include in the report (default: {}).
|
|
785
|
+
|
|
786
|
+
`px` and `py` are used to specify the data decomposition. For example, if you have a 2D array of size 1000x1000 and you partition it into 4 parts (2x2), you would set `px=2` and `py=2`.\
|
|
787
|
+
They can also be used in a single-device run to specify the batch size.
|
|
788
|
+
|
|
789
|
+
Some decomposition parameters are generated that are specific to 3D data decomposition:\
|
|
790
|
+
`slab_yz` if the distributed axis is the y-axis.\
|
|
791
|
+
`slab_xy` if the distributed axis is the x-axis.\
|
|
792
|
+
`pencils` if the distributed axes are the x and y axes.
|
|
793
|
+
|
|
794
|
+
### Multi-GPU Setup
|
|
795
|
+
|
|
796
|
+
In a multi-GPU setup, the times are automatically averaged across ranks, providing a single performance metric for the entire setup.
|
|
797
|
+
|
|
754
798
|
## CSV Structure
|
|
755
799
|
|
|
756
800
|
The CSV files should follow a specific structure to ensure proper processing and concatenation. The directory structure should be organized by GPU type, with subdirectories for the number of GPUs and the respective CSV files.
|
|
757
801
|
|
|
758
|
-
|
|
759
802
|
### Example Directory Structure
|
|
760
803
|
|
|
761
804
|
```
|
|
@@ -790,12 +833,12 @@ root_directory/
|
|
|
790
833
|
|
|
791
834
|
## Concatenating Files from Different Runs
|
|
792
835
|
|
|
793
|
-
The `plot` function expects the directory to be organized as described above, but with the different number of GPUs
|
|
836
|
+
The `plot` function expects the directory to be organized as described above, but with the different number of GPUs together in the same directory. The `concatenate` function can be used to concatenate the CSV files from different runs into a single file.
|
|
794
837
|
|
|
795
838
|
### Example Usage
|
|
796
839
|
|
|
797
840
|
```bash
|
|
798
|
-
hpc-
|
|
841
|
+
jax-hpc-profiler concat /path/to/root_directory /path/to/output
|
|
799
842
|
```
|
|
800
843
|
|
|
801
844
|
And the output will be:
|
|
@@ -812,8 +855,6 @@ out_directory/
|
|
|
812
855
|
└── method_3.csv
|
|
813
856
|
```
|
|
814
857
|
|
|
815
|
-
|
|
816
|
-
|
|
817
858
|
## Plotting CSV Data
|
|
818
859
|
|
|
819
860
|
You can plot the performance data using the `plot` command. The plotting command provides various options to customize the plots.
|
|
@@ -821,27 +862,41 @@ You can plot the performance data using the `plot` command. The plotting command
|
|
|
821
862
|
### Usage
|
|
822
863
|
|
|
823
864
|
```bash
|
|
824
|
-
hpc-
|
|
865
|
+
jax-hpc-profiler plot -f <csv_files> [options]
|
|
825
866
|
```
|
|
826
867
|
|
|
827
|
-
|
|
828
|
-
|
|
868
|
+
### Options
|
|
829
869
|
|
|
830
870
|
- `-f, --csv_files`: List of CSV files to plot (required).
|
|
831
|
-
- `-g, --gpus`:
|
|
832
|
-
- `-d, --data_size`:
|
|
871
|
+
- `-g, --gpus`: List of number of GPUs to plot.
|
|
872
|
+
- `-d, --data_size`: List of data sizes to plot.
|
|
833
873
|
- `-fd, --filter_pdims`: List of pdims to filter (e.g., 1x4 2x2 4x8).
|
|
834
|
-
- `-ps, --
|
|
835
|
-
- `plot_all`: Plot every decomposition.
|
|
874
|
+
- `-ps, --pdim_strategy`: Strategy for plotting pdims. Multiple values may be given (`plot_all`, `plot_fastest`, `slab_yz`, `slab_xy`, `pencils`).
|
|
875
|
+
- `plot_all`: Plot every decomposition.
|
|
836
876
|
- `plot_fastest`: Plot the fastest decomposition.
|
|
837
|
-
- `-
|
|
838
|
-
- `-fn, --function_name`: Function
|
|
839
|
-
- `-
|
|
840
|
-
- `-
|
|
877
|
+
- `-pr, --precision`: Precision to filter by. Multiple values may be given (`float32`, `float64`).
|
|
878
|
+
- `-fn, --function_name`: Function names to filter. Multiple values may be given.
|
|
879
|
+
- `-pt, --plot_times`: Time columns to plot (`jit_time`, `min_time`, `max_time`, `mean_time`, `std_time`, `last_time`). Note: You cannot plot memory and time together.
|
|
880
|
+
- `-pm, --plot_memory`: Memory columns to plot (`generated_code`, `argument_size`, `output_size`, `temp_size`). Note: You cannot plot memory and time together.
|
|
881
|
+
- `-mu, --memory_units`: Memory units to plot (`KB`, `MB`, `GB`, `TB`).
|
|
841
882
|
- `-fs, --figure_size`: Figure size.
|
|
842
|
-
- `-nl, --nodes_in_label`: Use node names in labels.
|
|
843
883
|
- `-o, --output`: Output file (if none then only show plot).
|
|
844
884
|
- `-db, --dark_bg`: Use dark background for plotting.
|
|
845
|
-
- `-pd, --print_decompositions`: Print decompositions on plot (
|
|
846
|
-
- `-b, --backends`: List of backends to include.
|
|
847
|
-
- `-sc, --scaling`: Scaling type (`Weak
|
|
885
|
+
- `-pd, --print_decompositions`: Print decompositions on plot (experimental).
|
|
886
|
+
- `-b, --backends`: List of backends to include. Multiple values may be given.
|
|
887
|
+
- `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
|
|
888
|
+
- `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
|
|
889
|
+
|
|
890
|
+
## Examples
|
|
891
|
+
|
|
892
|
+
The repository includes examples for both profiling and plotting.
|
|
893
|
+
|
|
894
|
+
### Profiling Example
|
|
895
|
+
|
|
896
|
+
See the `examples/profiling` directory for profiling examples, including `function.py`, `test.csv`, and the generated markdown report.
|
|
897
|
+
|
|
898
|
+
### Plotting Example
|
|
899
|
+
|
|
900
|
+
See the `examples/plotting` directory for plotting examples, including `generator.py`, `sample_data1.csv`, `sample_data2.csv`, and `sample_data3.csv`.
|
|
901
|
+
|
|
902
|
+
A multi-GPU example comparing distributed FFTs can be found here: [jaxdecomp-benchmarks](https://github.com/ASKabalan/jaxdecomp-benchmarks).
|
jax_hpc_profiler-0.2.0/README.md
DELETED
|
@@ -1,159 +0,0 @@
|
|
|
1
|
-
# HPC Plotter
|
|
2
|
-
|
|
3
|
-
HPC Plotter is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
|
|
4
|
-
|
|
5
|
-
## Table of Contents
|
|
6
|
-
- [Introduction](#introduction)
|
|
7
|
-
- [Installation](#installation)
|
|
8
|
-
- [Generating CSV Files Using the Timer Class](#generating-csv-files-using-the-timer-class)
|
|
9
|
-
- [CSV Structure](#csv-structure)
|
|
10
|
-
- [Concatenating Files from Different Runs](#concatenating-files-from-different-runs)
|
|
11
|
-
- [Plotting CSV Data](#plotting-csv-data)
|
|
12
|
-
|
|
13
|
-
## Introduction
|
|
14
|
-
HPC Plotter allows users to:
|
|
15
|
-
1. Generate CSV files containing performance data.
|
|
16
|
-
2. Concatenate multiple CSV files from different runs.
|
|
17
|
-
3. Plot the performance data for analysis.
|
|
18
|
-
|
|
19
|
-
## Installation
|
|
20
|
-
|
|
21
|
-
To install the package, run the following command:
|
|
22
|
-
|
|
23
|
-
```bash
|
|
24
|
-
pip install hpc-plotter
|
|
25
|
-
```
|
|
26
|
-
## Generating CSV Files Using the Timer Class
|
|
27
|
-
|
|
28
|
-
To generate CSV files, you can use the `Timer` class provided in the `hpc_plotter.timer` module. This class helps in timing functions and saving the timing results to CSV files.
|
|
29
|
-
|
|
30
|
-
### Example Usage
|
|
31
|
-
|
|
32
|
-
```python
|
|
33
|
-
import time
|
|
34
|
-
from hpc_plotter.timer import Timer
|
|
35
|
-
import jax
|
|
36
|
-
# Define the functions you want to time
|
|
37
|
-
def example_function():
|
|
38
|
-
time.sleep(1) # Simulating a task
|
|
39
|
-
|
|
40
|
-
# Create a Timer instance
|
|
41
|
-
timer = Timer()
|
|
42
|
-
|
|
43
|
-
# Time the function
|
|
44
|
-
timer.chrono_jit(example_function)
|
|
45
|
-
for _ in range(5):
|
|
46
|
-
timer.chrono_fun(example_function)
|
|
47
|
-
|
|
48
|
-
# Metadata for the CSV file
|
|
49
|
-
metadata = {
|
|
50
|
-
'rank': jax.process_index(),
|
|
51
|
-
'function_name': 'example_function',
|
|
52
|
-
'precision': 'float32',
|
|
53
|
-
'x': '1024',
|
|
54
|
-
'y': '1024',
|
|
55
|
-
'z': '1024',
|
|
56
|
-
'px': '4',
|
|
57
|
-
'py': '4',
|
|
58
|
-
'backend': 'NCCL',
|
|
59
|
-
'nodes': '2'
|
|
60
|
-
}
|
|
61
|
-
|
|
62
|
-
# Print the results to a CSV file
|
|
63
|
-
timer.print_to_csv('output.csv', **metadata)
|
|
64
|
-
```
|
|
65
|
-
|
|
66
|
-
## CSV Structure
|
|
67
|
-
|
|
68
|
-
The CSV files should follow a specific structure to ensure proper processing and concatenation. The directory structure should be organized by GPU type, with subdirectories for the number of GPUs and the respective CSV files.
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
### Example Directory Structure
|
|
72
|
-
|
|
73
|
-
```
|
|
74
|
-
root_directory/
|
|
75
|
-
├── gpu_1/
|
|
76
|
-
│ ├── 2/
|
|
77
|
-
│ │ ├── method_1.csv
|
|
78
|
-
│ │ ├── method_2.csv
|
|
79
|
-
│ │ └── method_3.csv
|
|
80
|
-
│ ├── 4/
|
|
81
|
-
│ │ ├── method_1.csv
|
|
82
|
-
│ │ ├── method_2.csv
|
|
83
|
-
│ │ └── method_3.csv
|
|
84
|
-
│ └── 8/
|
|
85
|
-
│ ├── method_1.csv
|
|
86
|
-
│ ├── method_2.csv
|
|
87
|
-
│ └── method_3.csv
|
|
88
|
-
└── gpu_2/
|
|
89
|
-
├── 2/
|
|
90
|
-
│ ├── method_1.csv
|
|
91
|
-
│ ├── method_2.csv
|
|
92
|
-
│ └── method_3.csv
|
|
93
|
-
├── 4/
|
|
94
|
-
│ ├── method_1.csv
|
|
95
|
-
│ ├── method_2.csv
|
|
96
|
-
│ └── method_3.csv
|
|
97
|
-
└── 8/
|
|
98
|
-
├── method_1.csv
|
|
99
|
-
├── method_2.csv
|
|
100
|
-
└── method_3.csv
|
|
101
|
-
```
|
|
102
|
-
|
|
103
|
-
## Concatenating Files from Different Runs
|
|
104
|
-
|
|
105
|
-
The `plot` function expects the directory to be organized as described above, but with the different number of GPUs toghether in the same directory. The `concatenate` function can be used to concatenate the CSV files from different runs into a single file.
|
|
106
|
-
|
|
107
|
-
### Example Usage
|
|
108
|
-
|
|
109
|
-
```bash
|
|
110
|
-
hpc-plotter concat /path/to/root_directory /path/to/output
|
|
111
|
-
```
|
|
112
|
-
|
|
113
|
-
And the output will be:
|
|
114
|
-
|
|
115
|
-
```
|
|
116
|
-
out_directory/
|
|
117
|
-
├── gpu_1/
|
|
118
|
-
│ ├── method_1.csv
|
|
119
|
-
│ ├── method_2.csv
|
|
120
|
-
│ └── method_3.csv
|
|
121
|
-
└── gpu_2/
|
|
122
|
-
├── method_1.csv
|
|
123
|
-
├── method_2.csv
|
|
124
|
-
└── method_3.csv
|
|
125
|
-
```
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
## Plotting CSV Data
|
|
130
|
-
|
|
131
|
-
You can plot the performance data using the `plot` command. The plotting command provides various options to customize the plots.
|
|
132
|
-
|
|
133
|
-
### Usage
|
|
134
|
-
|
|
135
|
-
```bash
|
|
136
|
-
hpc-plotter plot -f <csv_files> [options]
|
|
137
|
-
```
|
|
138
|
-
|
|
139
|
-
with options :
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
- `-f, --csv_files`: List of CSV files to plot (required).
|
|
143
|
-
- `-g, --gpus`: Filter GPUs. List of number of GPUs to plot.
|
|
144
|
-
- `-d, --data_size`: Filter data sizes. List of data sizes to plot.
|
|
145
|
-
- `-fd, --filter_pdims`: List of pdims to filter (e.g., 1x4 2x2 4x8).
|
|
146
|
-
- `-ps, --pdims_strategy`: Strategy for plotting pdims (`plot_all` or `plot_fastest`).
|
|
147
|
-
- `plot_all`: Plot every decomposition. 1xX and Xx1 as slabs, XxX as pencils.
|
|
148
|
-
- `plot_fastest`: Plot the fastest decomposition.
|
|
149
|
-
- `-p, --precision`: Precision to filter by (`float32` or `float64`).
|
|
150
|
-
- `-fn, --function_name`: Function name to filter.
|
|
151
|
-
- `-ta, --time_aggregation`: Time aggregation method (`mean`, `min`, `max`).
|
|
152
|
-
- `-tc, --time_column`: Time column to plot (`jit_time`, `min_time`, `max_time`, `mean_time`, `std_div`, `last_time`).
|
|
153
|
-
- `-fs, --figure_size`: Figure size.
|
|
154
|
-
- `-nl, --nodes_in_label`: Use node names in labels.
|
|
155
|
-
- `-o, --output`: Output file (if none then only show plot).
|
|
156
|
-
- `-db, --dark_bg`: Use dark background for plotting.
|
|
157
|
-
- `-pd, --print_decompositions`: Print decompositions on plot (only for `plot_fastest`).
|
|
158
|
-
- `-b, --backends`: List of backends to include.
|
|
159
|
-
- `-sc, --scaling`: Scaling type (`Weak` or `Strong`).
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
[build-system]
|
|
2
|
-
requires = ["setuptools", "wheel"]
|
|
3
|
-
build-backend = "setuptools.build_meta"
|
|
4
|
-
|
|
5
|
-
[project]
|
|
6
|
-
name = "jax_hpc_profiler"
|
|
7
|
-
version = "0.2.0"
|
|
8
|
-
description = "HPC Plotter and profiler for benchmarking data made for JAX"
|
|
9
|
-
authors = [
|
|
10
|
-
{ name="Wassim Kabalan" }
|
|
11
|
-
]
|
|
12
|
-
dependencies = [
|
|
13
|
-
"numpy",
|
|
14
|
-
"pandas",
|
|
15
|
-
"matplotlib",
|
|
16
|
-
"seaborn",
|
|
17
|
-
"tabulate"
|
|
18
|
-
]
|
|
19
|
-
readme = "README.md"
|
|
20
|
-
license = { file = "LICENSE" }
|
|
21
|
-
|
|
22
|
-
[project.scripts]
|
|
23
|
-
jhp = "jax_hpc_profiler.main:main"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/entry_points.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/requires.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.0 → jax_hpc_profiler-0.2.1}/src/jax_hpc_profiler.egg-info/top_level.txt
RENAMED
|
File without changes
|