jax-hpc-profiler 0.2.8__tar.gz → 0.2.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21) hide show
  1. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/PKG-INFO +1 -1
  2. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/pyproject.toml +1 -1
  3. jax_hpc_profiler-0.2.9/src/jax_hpc_profiler/create_argparse.py +197 -0
  4. jax_hpc_profiler-0.2.9/src/jax_hpc_profiler/main.py +65 -0
  5. jax_hpc_profiler-0.2.9/src/jax_hpc_profiler/plotting.py +308 -0
  6. jax_hpc_profiler-0.2.9/src/jax_hpc_profiler/timer.py +258 -0
  7. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler/utils.py +15 -2
  8. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
  9. jax_hpc_profiler-0.2.8/src/jax_hpc_profiler/create_argparse.py +0 -161
  10. jax_hpc_profiler-0.2.8/src/jax_hpc_profiler/main.py +0 -63
  11. jax_hpc_profiler-0.2.8/src/jax_hpc_profiler/plotting.py +0 -217
  12. jax_hpc_profiler-0.2.8/src/jax_hpc_profiler/timer.py +0 -218
  13. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/LICENSE +0 -0
  14. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/README.md +0 -0
  15. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/setup.cfg +0 -0
  16. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler/__init__.py +0 -0
  17. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  18. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  19. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  20. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  21. {jax_hpc_profiler-0.2.8 → jax_hpc_profiler-0.2.9}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.8"
7
+ version = "0.2.9"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -0,0 +1,197 @@
1
+ import argparse
2
+
3
+
4
def create_argparser(argv=None):
    """
    Build the HPC Plotter command-line parser and parse the arguments.

    Three sub-commands are registered: ``concat``, ``plot`` and
    ``label_help``. Exactly one of them must be supplied.

    Parameters
    ----------
    argv : list of str, optional
        Command-line tokens to parse, without the program name. When None
        (the default) argparse reads ``sys.argv[1:]``, which is what the
        console entry point relies on. Passing an explicit list allows the
        parser to be driven programmatically.

    Returns
    -------
    argparse.Namespace
        Parsed and validated arguments. For the ``plot`` command the
        namespace additionally carries ``plot_columns``, set to whichever of
        ``plot_times`` / ``plot_memory`` was given, and ``pdim_strategy`` is
        reduced to a single entry when ``plot_all`` or ``plot_fastest`` was
        combined with other strategies.

    Raises
    ------
    ValueError
        If the ``plot`` command was parsed with neither ``plot_times`` nor
        ``plot_memory`` set.
    """
    parser = argparse.ArgumentParser(
        description="HPC Plotter for benchmarking data")

    # One sub-command per mode of operation; exactly one is required.
    subparsers = parser.add_subparsers(dest="command", required=True)

    concat_parser = subparsers.add_parser("concat",
                                          help="Concatenate CSV files")
    concat_parser.add_argument("input",
                               type=str,
                               help="Input directory for concatenation")
    concat_parser.add_argument("output",
                               type=str,
                               help="Output directory for concatenation")

    # Arguments for plotting
    plot_parser = subparsers.add_parser("plot", help="Plot CSV data")
    plot_parser.add_argument("-f",
                             "--csv_files",
                             nargs="+",
                             help="List of CSV files to plot",
                             required=True)
    plot_parser.add_argument(
        "-g",
        "--gpus",
        nargs="*",
        type=int,
        help="List of number of GPUs to plot",
        default=None,
    )
    plot_parser.add_argument(
        "-d",
        "--data_size",
        nargs="*",
        type=int,
        help="List of data sizes to plot",
        default=None,
    )

    # pdims related arguments
    plot_parser.add_argument(
        "-fd",
        "--filter_pdims",
        nargs="*",
        help="List of pdims to filter, e.g., 1x4 2x2 4x8",
        default=None,
    )
    plot_parser.add_argument(
        "-ps",
        "--pdim_strategy",
        choices=["plot_all", "plot_fastest", "slab_yz", "slab_xy", "pencils"],
        nargs="*",
        default=["plot_fastest"],
        help="Strategy for plotting pdims",
    )

    # Function and precision related arguments
    plot_parser.add_argument(
        "-pr",
        "--precision",
        choices=["float32", "float64"],
        default=["float32", "float64"],
        nargs="*",
        help="Precision to filter by (float32 or float64)",
    )
    plot_parser.add_argument(
        "-fn",
        "--function_name",
        nargs="+",
        help="Function names to filter",
        default=None,
    )

    # Time or memory related arguments: exactly one of the two is required.
    plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
    plotting_group.add_argument(
        "-pt",
        "--plot_times",
        nargs="*",
        choices=[
            "jit_time",
            "min_time",
            "max_time",
            "mean_time",
            "std_time",
            "last_time",
        ],
        help="Time columns to plot",
    )
    plotting_group.add_argument(
        "-pm",
        "--plot_memory",
        nargs="*",
        choices=[
            "generated_code", "argument_size", "output_size", "temp_size"
        ],
        help="Memory columns to plot",
    )
    plot_parser.add_argument(
        "-mu",
        "--memory_units",
        default="GB",
        help="Memory units to plot (KB, MB, GB, TB)",
    )

    # Plot customization arguments
    plot_parser.add_argument("-fs",
                             "--figure_size",
                             nargs=2,
                             type=int,
                             help="Figure size",
                             default=(10, 6))
    plot_parser.add_argument("-o",
                             "--output",
                             help="Output file (if none then only show plot)",
                             default=None)
    plot_parser.add_argument("-db",
                             "--dark_bg",
                             action="store_true",
                             help="Use dark background for plotting")
    plot_parser.add_argument(
        "-pd",
        "--print_decompositions",
        action="store_true",
        help="Print decompositions on plot",
    )

    # Backend related arguments
    plot_parser.add_argument(
        "-b",
        "--backends",
        nargs="*",
        default=["MPI", "NCCL", "MPI4JAX"],
        help="List of backends to include",
    )

    # Scaling type argument
    plot_parser.add_argument(
        "-sc",
        "--scaling",
        choices=["Weak", "Strong", "w", "s"],
        required=True,
        help="Scaling type (Weak or Strong)",
    )

    # Label customization argument. argparse %-formats help strings, hence
    # the doubled percent signs; the default value is not formatted.
    plot_parser.add_argument(
        "-l",
        "--label_text",
        type=str,
        help=
        ("Custom label for the plot. You can use placeholders: %%decomposition%% "
         "(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), "
         "%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%), "
         "%%function%% (or %%f%%)"),
        default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
    )

    subparsers.add_parser("label_help", help="Label customization help")

    args = parser.parse_args(argv)

    # If the command was plot, check that pdim_strategy is valid.
    if args.command == "plot":
        # 'plot_all' and 'plot_fastest' are exclusive of every other
        # strategy; when mixed, fall back to the exclusive one and warn.
        if "plot_all" in args.pdim_strategy and len(args.pdim_strategy) > 1:
            print(
                "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
            )
            args.pdim_strategy = ["plot_all"]

        if "plot_fastest" in args.pdim_strategy and len(
                args.pdim_strategy) > 1:
            print(
                "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
            )
            args.pdim_strategy = ["plot_fastest"]

        # Expose the selected columns under a single attribute so callers do
        # not need to know whether time or memory was requested.
        if args.plot_times is not None:
            args.plot_columns = args.plot_times
        elif args.plot_memory is not None:
            args.plot_columns = args.plot_memory
        else:
            raise ValueError(
                "Either plot_times or plot_memory should be provided")

    return args
@@ -0,0 +1,65 @@
1
+ import sys
2
+ from typing import List, Optional
3
+
4
+ from .create_argparse import create_argparser
5
+ from .plotting import plot_strong_scaling, plot_weak_scaling
6
+ from .utils import clean_up_csv, concatenate_csvs
7
+
8
+
9
def main():
    """
    Command-line entry point.

    Parses the arguments with ``create_argparser`` and dispatches on the
    selected sub-command: ``concat`` concatenates CSV files, ``label_help``
    prints the label placeholders, and ``plot`` calls the weak or strong
    scaling plot function depending on ``--scaling``.
    """
    args = create_argparser()
    command = args.command

    if command == "concat":
        concatenate_csvs(args.input, args.output)
        return

    if command == "label_help":
        help_lines = (
            "Customize the label text for the plot. using these commands.",
            " -- %m% or %methodname%: method name",
            " -- %f% or %function%: function name",
            " -- %pn% or %plot_name%: plot name",
            " -- %pr% or %precision%: precision",
            " -- %b% or %backend%: backend",
            " -- %p% or %pdims%: pdims",
            " -- %n% or %node%: node",
        )
        for line in help_lines:
            print(line)
        return

    if command != "plot":
        return

    # Both plot functions take the same positional arguments, so pick the
    # function first and call it once.
    scaling = args.scaling.lower()
    if scaling in ("weak", "w"):
        plot_function = plot_weak_scaling
    elif scaling in ("strong", "s"):
        plot_function = plot_strong_scaling
    else:
        return

    plot_function(
        args.csv_files,
        args.gpus,
        args.data_size,
        args.function_name,
        args.precision,
        args.filter_pdims,
        args.pdim_strategy,
        args.print_decompositions,
        args.backends,
        args.plot_columns,
        args.memory_units,
        args.label_text,
        args.figure_size,
        args.dark_bg,
        args.output,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,308 @@
1
+ from itertools import product
2
+ from typing import Dict, List, Optional
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ import seaborn as sns
8
+ from matplotlib.axes import Axes
9
+ from matplotlib.patches import FancyBboxPatch
10
+
11
+ from .utils import clean_up_csv, inspect_df, plot_with_pdims_strategy
12
+
13
+ np.seterr(divide="ignore")  # Process-wide NumPy error state: ignore divide-by-zero. NOTE(review): presumably for the np.log2/np.log10 calls in configure_axes - confirm.
14
+
15
+
16
def configure_axes(
    ax: Axes,
    x_values: List[int],
    y_values: List[float],
    xlabel: str,
    title: str,
    plotting_memory: bool = False,
    memory_units: str = "bytes",
):
    """
    Configure the axes for the plot.

    Sets limits, scales, ticks, labels, title and legend on ``ax`` and draws
    a dashed vertical guide line at every x value.

    Parameters
    ----------
    ax : Axes
        The axes to configure.
    x_values : List[int]
        The x-axis values. Must be non-empty; they become the x ticks.
    y_values : List[float]
        The y-axis values. Must be non-empty; they determine the y limits.
    xlabel : str
        The label for the x-axis.
    title : str
        The title of the axes.
    plotting_memory : bool, optional
        When True the y-axis is labelled as memory and keeps its default
        scale; otherwise it is labelled as time and uses a symlog scale with
        power-of-ten ticks. By default False.
    memory_units : str, optional
        Unit shown in the y-axis label when plotting memory, by default
        "bytes".
    """
    ylabel = (f"Memory ({memory_units})"
              if plotting_memory else "Time (milliseconds)")

    def to_log2(values):
        return np.log2(values)

    def from_log2(values):
        return 2**values

    y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1

    ax.set_xlim([min(x_values), max(x_values)])
    ax.set_title(title)
    ax.set_ylim([y_min, y_max])
    # Space the x axis by log2 so that doubling steps are equidistant.
    ax.set_xscale("function", functions=(to_log2, from_log2))

    if not plotting_memory:
        ax.set_yscale("symlog")
        # log10 is only defined for positive numbers: a zero or negative
        # value would give -inf/nan and make int() raise, so the tick range
        # is derived from the positive values only.
        positive_values = [y for y in y_values if y > 0]
        if positive_values:
            lowest_exp = int(np.floor(np.log10(min(positive_values) * 0.6)))
            highest_exp = int(np.ceil(np.log10(max(positive_values) * 1.1)))
            ax.set_yticks([10**t for t in range(lowest_exp, highest_exp + 1)])

    ax.set_xticks(x_values)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    for x_value in x_values:
        ax.axvline(x=x_value, color="gray", linestyle="--", alpha=0.5)
    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, 0.05),
        ncol=4,
        fontsize="x-large",
        prop={"size": 14},
    )
67
+
68
+
69
def plot_scaling(
    dataframes: Dict[str, pd.DataFrame],
    fixed_sizes: List[int],
    size_column: str,
    fixed_column: str,
    xlabel: str,
    title: str,
    figure_size: tuple = (6, 4),
    output: Optional[str] = None,
    dark_bg: bool = False,
    print_decompositions: bool = False,
    backends: List[str] = ["NCCL"],
    precisions: List[str] = ["float32"],
    functions: List[str] | None = None,
    plot_columns: List[str] = ["mean_time"],
    memory_units: str = "bytes",
    label_text: str = "plot",
    pdims_strategy: List[str] = ["plot_fastest"],
):
    """
    General scaling plot function based on the number of GPUs or data size.

    One subplot is drawn per entry of ``fixed_sizes``. Within a subplot,
    every combination of backend, precision, function and plot column that
    has data is drawn through ``plot_with_pdims_strategy``.

    Parameters
    ----------
    dataframes : Dict[str, pd.DataFrame]
        Dictionary of method names to dataframes.
    fixed_sizes : List[int]
        List of fixed sizes (data or GPUs) to plot, one subplot each.
    size_column : str
        Column name for the size axis ('x' for weak scaling, 'gpus' for strong scaling).
    fixed_column : str
        Column name for the fixed axis ('gpus' for weak scaling, 'x' for strong scaling).
    xlabel : str
        Label for the x-axis.
    title : str
        Subplot title prefix; the fixed size is appended to it.
    figure_size : tuple, optional
        Size of the figure, by default (6, 4).
    output : Optional[str], optional
        Output file to save the plot, by default None (show the plot).
    dark_bg : bool, optional
        Whether to use dark background for the plot, by default False.
    print_decompositions : bool, optional
        Whether to print decompositions on the plot, by default False.
    backends : List[str], optional
        List of backends to include, by default ["NCCL"].
    precisions : List[str], optional
        List of precisions to include, by default ["float32"].
    functions : List[str] | None, optional
        Function names to include. When None, the functions present in each
        dataframe are used. By default None.
    plot_columns : List[str], optional
        Columns to plot, by default ["mean_time"]. The y-axis is treated as
        memory when the first column name does not contain "time".
    memory_units : str, optional
        Unit used for the memory axis label, by default "bytes".
    label_text : str, optional
        Label template forwarded to ``plot_with_pdims_strategy``.
    pdims_strategy : List[str], optional
        Strategies for plotting pdims, by default ["plot_fastest"].
    """
    num_subplots = len(fixed_sizes)
    if num_subplots == 0:
        # Without this guard the grid computation below divides by zero.
        print(f"No {fixed_column} values to plot. Exiting...")
        return

    if dark_bg:
        plt.style.use("dark_background")

    # Near-square grid large enough to hold one subplot per fixed size.
    num_rows = int(np.ceil(np.sqrt(num_subplots)))
    num_cols = int(np.ceil(num_subplots / num_rows))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=figure_size)
    if num_subplots > 1:
        axs = axs.flatten()
    else:
        axs = [axs]

    for i, fixed_size in enumerate(fixed_sizes):
        ax = axs[i]

        x_values = []
        y_values = []
        for method, df in dataframes.items():

            filtered_method_df = df[df[fixed_column] == int(fixed_size)]
            if filtered_method_df.empty:
                continue
            filtered_method_df = filtered_method_df.sort_values(
                by=[size_column])
            # Resolve the function list per dataframe without overwriting the
            # `functions` parameter, so one method's functions do not leak
            # into the other methods and subplots.
            selected_functions = (pd.unique(filtered_method_df["function"])
                                  if functions is None else functions)
            combinations = product(backends, precisions, selected_functions,
                                   plot_columns)

            for backend, precision, function, plot_column in combinations:

                filtered_params_df = filtered_method_df[
                    (filtered_method_df["backend"] == backend)
                    & (filtered_method_df["precision"] == precision)
                    & (filtered_method_df["function"] == function)]
                if filtered_params_df.empty:
                    continue
                x_vals, y_vals = plot_with_pdims_strategy(
                    ax,
                    filtered_params_df,
                    method,
                    pdims_strategy,
                    print_decompositions,
                    size_column,
                    plot_column,
                    label_text,
                )

                x_values.extend(x_vals)
                y_values.extend(y_vals)

        if len(x_values) != 0:
            plotting_memory = "time" not in plot_columns[0].lower()
            # Pass by keyword: configure_axes takes xlabel before title, and
            # a positional call here previously swapped the two.
            configure_axes(
                ax,
                x_values,
                y_values,
                xlabel=xlabel,
                title=f"{title} {fixed_size}",
                plotting_memory=plotting_memory,
                memory_units=memory_units,
            )

    # Remove the unused cells of the grid.
    for i in range(num_subplots, num_rows * num_cols):
        fig.delaxes(axs[i])

    fig.tight_layout()
    # NOTE(review): the patch is appended without an explicit transform;
    # confirm it renders as the intended frame.
    rect = FancyBboxPatch((0.1, 0.1),
                          0.8,
                          0.8,
                          boxstyle="round,pad=0.02",
                          ec="black",
                          fc="none")
    fig.patches.append(rect)
    if output is None:
        plt.show()
    else:
        plt.savefig(output, bbox_inches="tight", transparent=True)
196
+
197
+
198
def plot_strong_scaling(
    csv_files: List[str],
    fixed_gpu_size: Optional[List[int]] = None,
    fixed_data_size: Optional[List[int]] = None,
    functions: List[str] | None = None,
    precisions: List[str] = ["float32"],
    pdims: Optional[List[str]] = None,
    pdims_strategy: List[str] = ["plot_fastest"],
    print_decompositions: bool = False,
    backends: List[str] = ["NCCL"],
    plot_columns: List[str] = ["mean_time"],
    memory_units: str = "bytes",
    label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
    figure_size: tuple = (6, 4),
    dark_bg: bool = False,
    output: Optional[str] = None,
):
    """
    Plot strong scaling based on the number of GPUs.

    The CSV files are loaded and filtered through ``clean_up_csv``; the data
    sizes it reports become the fixed sizes handed to ``plot_scaling``, with
    the 'gpus' column on the size axis and the 'x' column as the fixed
    column. Nothing is plotted when no dataframe is returned.

    The filter arguments (``precisions``, ``functions``, ``fixed_gpu_size``,
    ``fixed_data_size``, ``pdims``, ``pdims_strategy``, ``backends``,
    ``memory_units``) are forwarded to ``clean_up_csv``; the remaining
    arguments are forwarded to ``plot_scaling``.
    """
    frames_by_method, _gpu_counts, data_sizes = clean_up_csv(
        csv_files,
        precisions,
        functions,
        fixed_gpu_size,
        fixed_data_size,
        pdims,
        pdims_strategy,
        backends,
        memory_units,
    )
    if not len(frames_by_method):
        print("No dataframes found for the given arguments. Exiting...")
        return

    plot_scaling(
        dataframes=frames_by_method,
        fixed_sizes=data_sizes,
        size_column="gpus",
        fixed_column="x",
        xlabel="Number of GPUs",
        title="Data size",
        figure_size=figure_size,
        output=output,
        dark_bg=dark_bg,
        print_decompositions=print_decompositions,
        backends=backends,
        precisions=precisions,
        functions=functions,
        plot_columns=plot_columns,
        memory_units=memory_units,
        label_text=label_text,
        pdims_strategy=pdims_strategy,
    )
253
+
254
+
255
def plot_weak_scaling(
    csv_files: List[str],
    fixed_gpu_size: Optional[List[int]] = None,
    fixed_data_size: Optional[List[int]] = None,
    functions: List[str] | None = None,
    precisions: List[str] = ["float32"],
    pdims: Optional[List[str]] = None,
    pdims_strategy: List[str] = ["plot_fastest"],
    print_decompositions: bool = False,
    backends: List[str] = ["NCCL"],
    plot_columns: List[str] = ["mean_time"],
    memory_units: str = "bytes",
    label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
    figure_size: tuple = (6, 4),
    dark_bg: bool = False,
    output: Optional[str] = None,
):
    """
    Plot weak scaling based on the data size.

    The CSV files are loaded and filtered through ``clean_up_csv``; the GPU
    counts it reports become the fixed sizes handed to ``plot_scaling``, with
    the 'x' column on the size axis and the 'gpus' column as the fixed
    column. Nothing is plotted when no dataframe is returned.

    The filter arguments (``precisions``, ``functions``, ``fixed_gpu_size``,
    ``fixed_data_size``, ``pdims``, ``pdims_strategy``, ``backends``,
    ``memory_units``) are forwarded to ``clean_up_csv``; the remaining
    arguments are forwarded to ``plot_scaling``.
    """
    frames_by_method, gpu_counts, _data_sizes = clean_up_csv(
        csv_files,
        precisions,
        functions,
        fixed_gpu_size,
        fixed_data_size,
        pdims,
        pdims_strategy,
        backends,
        memory_units,
    )
    if not len(frames_by_method):
        print("No dataframes found for the given arguments. Exiting...")
        return

    plot_scaling(
        dataframes=frames_by_method,
        fixed_sizes=gpu_counts,
        size_column="x",
        fixed_column="gpus",
        xlabel="Data size",
        title="Number of GPUs",
        figure_size=figure_size,
        output=output,
        dark_bg=dark_bg,
        print_decompositions=print_decompositions,
        backends=backends,
        precisions=precisions,
        functions=functions,
        plot_columns=plot_columns,
        memory_units=memory_units,
        label_text=label_text,
        pdims_strategy=pdims_strategy,
    )