jax-hpc-profiler 0.2.9__tar.gz → 0.2.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/PKG-INFO +3 -4
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/README.md +0 -2
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/pyproject.toml +1 -1
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/create_argparse.py +13 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/main.py +4 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/plotting.py +23 -14
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/timer.py +63 -32
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/utils.py +4 -3
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/PKG-INFO +3 -4
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/__init__.py +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: jax_hpc_profiler
|
|
3
|
-
Version: 0.2.9
|
|
3
|
+
Version: 0.2.11
|
|
4
4
|
Summary: HPC Plotter and profiler for benchmarking data made for JAX
|
|
5
5
|
Author: Wassim Kabalan
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -698,8 +698,7 @@ Requires-Dist: pandas
|
|
|
698
698
|
Requires-Dist: matplotlib
|
|
699
699
|
Requires-Dist: seaborn
|
|
700
700
|
Requires-Dist: tabulate
|
|
701
|
-
|
|
702
|
-
Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
|
|
701
|
+
Dynamic: license-file
|
|
703
702
|
|
|
704
703
|
# JAX HPC Profiler
|
|
705
704
|
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
|
|
2
|
-
|
|
3
1
|
# JAX HPC Profiler
|
|
4
2
|
|
|
5
3
|
JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
|
|
@@ -168,6 +168,19 @@ def create_argparser():
|
|
|
168
168
|
default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
169
169
|
)
|
|
170
170
|
|
|
171
|
+
plot_parser.add_argument(
|
|
172
|
+
"-xl",
|
|
173
|
+
"--xlabel",
|
|
174
|
+
type=str,
|
|
175
|
+
help="X-axis label for the plot",
|
|
176
|
+
)
|
|
177
|
+
plot_parser.add_argument(
|
|
178
|
+
"-tl",
|
|
179
|
+
"--title",
|
|
180
|
+
type=str,
|
|
181
|
+
help="Title for the plot",
|
|
182
|
+
)
|
|
183
|
+
|
|
171
184
|
subparsers.add_parser("label_help", help="Label customization help")
|
|
172
185
|
|
|
173
186
|
args = parser.parse_args()
|
|
@@ -37,6 +37,8 @@ def main():
|
|
|
37
37
|
args.plot_columns,
|
|
38
38
|
args.memory_units,
|
|
39
39
|
args.label_text,
|
|
40
|
+
args.title,
|
|
41
|
+
args.label_text,
|
|
40
42
|
args.figure_size,
|
|
41
43
|
args.dark_bg,
|
|
42
44
|
args.output,
|
|
@@ -55,6 +57,8 @@ def main():
|
|
|
55
57
|
args.plot_columns,
|
|
56
58
|
args.memory_units,
|
|
57
59
|
args.label_text,
|
|
60
|
+
args.title,
|
|
61
|
+
args.label_text,
|
|
58
62
|
args.figure_size,
|
|
59
63
|
args.dark_bg,
|
|
60
64
|
args.output,
|
|
@@ -17,8 +17,8 @@ def configure_axes(
|
|
|
17
17
|
ax: Axes,
|
|
18
18
|
x_values: List[int],
|
|
19
19
|
y_values: List[float],
|
|
20
|
-
xlabel: str,
|
|
21
20
|
title: str,
|
|
21
|
+
xlabel: str,
|
|
22
22
|
plotting_memory: bool = False,
|
|
23
23
|
memory_units: str = "bytes",
|
|
24
24
|
):
|
|
@@ -77,9 +77,9 @@ def plot_scaling(
|
|
|
77
77
|
output: Optional[str] = None,
|
|
78
78
|
dark_bg: bool = False,
|
|
79
79
|
print_decompositions: bool = False,
|
|
80
|
-
backends: List[str] =
|
|
81
|
-
precisions: List[str] =
|
|
82
|
-
functions: List[str]
|
|
80
|
+
backends: Optional[List[str]] = None,
|
|
81
|
+
precisions: Optional[List[str]] = None,
|
|
82
|
+
functions: Optional[List[str]] = None,
|
|
83
83
|
plot_columns: List[str] = ["mean_time"],
|
|
84
84
|
memory_units: str = "bytes",
|
|
85
85
|
label_text: str = "plot",
|
|
@@ -141,6 +141,11 @@ def plot_scaling(
|
|
|
141
141
|
by=[size_column])
|
|
142
142
|
functions = (pd.unique(filtered_method_df["function"])
|
|
143
143
|
if functions is None else functions)
|
|
144
|
+
precisions = (pd.unique(filtered_method_df["precision"])
|
|
145
|
+
if precisions is None else precisions)
|
|
146
|
+
backends = (pd.unique(filtered_method_df["backend"])
|
|
147
|
+
if backends is None else backends)
|
|
148
|
+
|
|
144
149
|
combinations = product(backends, precisions, functions,
|
|
145
150
|
plot_columns)
|
|
146
151
|
|
|
@@ -199,15 +204,17 @@ def plot_strong_scaling(
|
|
|
199
204
|
csv_files: List[str],
|
|
200
205
|
fixed_gpu_size: Optional[List[int]] = None,
|
|
201
206
|
fixed_data_size: Optional[List[int]] = None,
|
|
202
|
-
functions: List[str]
|
|
203
|
-
precisions: List[str] =
|
|
207
|
+
functions: Optional[List[str]] = None,
|
|
208
|
+
precisions: Optional[List[str]] = None,
|
|
204
209
|
pdims: Optional[List[str]] = None,
|
|
205
210
|
pdims_strategy: List[str] = ["plot_fastest"],
|
|
206
211
|
print_decompositions: bool = False,
|
|
207
|
-
backends: List[str] =
|
|
212
|
+
backends: Optional[List[str]] = None,
|
|
208
213
|
plot_columns: List[str] = ["mean_time"],
|
|
209
214
|
memory_units: str = "bytes",
|
|
210
215
|
label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
216
|
+
xlabel: str = "Number of GPUs",
|
|
217
|
+
title: str = "Data sizes",
|
|
211
218
|
figure_size: tuple = (6, 4),
|
|
212
219
|
dark_bg: bool = False,
|
|
213
220
|
output: Optional[str] = None,
|
|
@@ -236,8 +243,8 @@ def plot_strong_scaling(
|
|
|
236
243
|
available_data_sizes,
|
|
237
244
|
"gpus",
|
|
238
245
|
"x",
|
|
239
|
-
|
|
240
|
-
|
|
246
|
+
xlabel,
|
|
247
|
+
title,
|
|
241
248
|
figure_size,
|
|
242
249
|
output,
|
|
243
250
|
dark_bg,
|
|
@@ -256,15 +263,17 @@ def plot_weak_scaling(
|
|
|
256
263
|
csv_files: List[str],
|
|
257
264
|
fixed_gpu_size: Optional[List[int]] = None,
|
|
258
265
|
fixed_data_size: Optional[List[int]] = None,
|
|
259
|
-
functions: List[str]
|
|
260
|
-
precisions: List[str] =
|
|
266
|
+
functions: Optional[List[str]] = None,
|
|
267
|
+
precisions: Optional[List[str]] = None,
|
|
261
268
|
pdims: Optional[List[str]] = None,
|
|
262
269
|
pdims_strategy: List[str] = ["plot_fastest"],
|
|
263
270
|
print_decompositions: bool = False,
|
|
264
|
-
backends: List[str] =
|
|
271
|
+
backends: Optional[List[str]] = None,
|
|
265
272
|
plot_columns: List[str] = ["mean_time"],
|
|
266
273
|
memory_units: str = "bytes",
|
|
267
274
|
label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
275
|
+
xlabel: str = "Data sizes",
|
|
276
|
+
title: str = "Number of GPUs",
|
|
268
277
|
figure_size: tuple = (6, 4),
|
|
269
278
|
dark_bg: bool = False,
|
|
270
279
|
output: Optional[str] = None,
|
|
@@ -292,8 +301,8 @@ def plot_weak_scaling(
|
|
|
292
301
|
available_gpu_counts,
|
|
293
302
|
"x",
|
|
294
303
|
"gpus",
|
|
295
|
-
|
|
296
|
-
|
|
304
|
+
xlabel,
|
|
305
|
+
title,
|
|
297
306
|
figure_size,
|
|
298
307
|
output,
|
|
299
308
|
dark_bg,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import time
|
|
3
3
|
from functools import partial
|
|
4
|
-
from typing import Any, Callable, List, Tuple
|
|
4
|
+
from typing import Any, Callable, List, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
@@ -11,23 +11,31 @@ from jax.experimental import mesh_utils
|
|
|
11
11
|
from jax.experimental.shard_map import shard_map
|
|
12
12
|
from jax.sharding import Mesh, NamedSharding
|
|
13
13
|
from jax.sharding import PartitionSpec as P
|
|
14
|
+
from jaxtyping import Array
|
|
14
15
|
from tabulate import tabulate
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class Timer:
|
|
18
19
|
|
|
19
|
-
def __init__(self,
|
|
20
|
+
def __init__(self,
|
|
21
|
+
save_jaxpr=False,
|
|
22
|
+
compile_info=True,
|
|
23
|
+
jax_fn=True,
|
|
24
|
+
devices=None,
|
|
25
|
+
static_argnums=()):
|
|
20
26
|
self.jit_time = 0.0
|
|
21
27
|
self.times = []
|
|
22
28
|
self.profiling_data = {}
|
|
23
29
|
self.compiled_code = {}
|
|
24
30
|
self.save_jaxpr = save_jaxpr
|
|
31
|
+
self.compile_info = compile_info
|
|
25
32
|
self.jax_fn = jax_fn
|
|
26
33
|
self.devices = devices
|
|
34
|
+
self.static_argnums = static_argnums
|
|
27
35
|
|
|
28
36
|
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
29
37
|
|
|
30
|
-
if not self.jax_fn:
|
|
38
|
+
if not (self.jax_fn and self.compile_info):
|
|
31
39
|
return memory_analysis
|
|
32
40
|
|
|
33
41
|
sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
|
|
@@ -47,23 +55,36 @@ class Timer:
|
|
|
47
55
|
memory_analysis.temp_size_in_bytes,
|
|
48
56
|
)
|
|
49
57
|
|
|
50
|
-
def chrono_jit(self, fun: Callable, *args,
|
|
58
|
+
def chrono_jit(self, fun: Callable, *args, **kwargs) -> Array:
|
|
51
59
|
start = time.perf_counter()
|
|
52
|
-
out = fun(*args)
|
|
60
|
+
out = fun(*args, **kwargs)
|
|
53
61
|
if self.jax_fn:
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
62
|
+
|
|
63
|
+
def _block(x):
|
|
64
|
+
if isinstance(x, Array):
|
|
65
|
+
x.block_until_ready()
|
|
66
|
+
|
|
67
|
+
jax.tree_map(_block, out)
|
|
58
68
|
end = time.perf_counter()
|
|
59
69
|
self.jit_time = (end - start) * 1e3
|
|
60
70
|
|
|
71
|
+
self.compiled_code["JAXPR"] = "N/A"
|
|
72
|
+
self.compiled_code["LOWERED"] = "N/A"
|
|
73
|
+
self.compiled_code["COMPILED"] = "N/A"
|
|
74
|
+
self.profiling_data["generated_code"] = "N/A"
|
|
75
|
+
self.profiling_data["argument_size"] = "N/A"
|
|
76
|
+
self.profiling_data["output_size"] = "N/A"
|
|
77
|
+
self.profiling_data["temp_size"] = "N/A"
|
|
78
|
+
|
|
61
79
|
if self.save_jaxpr:
|
|
62
|
-
jaxpr = make_jaxpr(fun
|
|
80
|
+
jaxpr = make_jaxpr(fun,
|
|
81
|
+
static_argnums=self.static_argnums)(*args,
|
|
82
|
+
**kwargs)
|
|
63
83
|
self.compiled_code["JAXPR"] = jaxpr.pretty_print()
|
|
64
84
|
|
|
65
|
-
if self.jax_fn:
|
|
66
|
-
lowered = jax.jit(fun).lower(
|
|
85
|
+
if self.jax_fn and self.compile_info:
|
|
86
|
+
lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
|
|
87
|
+
*args, **kwargs)
|
|
67
88
|
compiled = lowered.compile()
|
|
68
89
|
memory_analysis = self._read_memory_analysis(
|
|
69
90
|
compiled.memory_analysis())
|
|
@@ -77,19 +98,21 @@ class Timer:
|
|
|
77
98
|
|
|
78
99
|
return out
|
|
79
100
|
|
|
80
|
-
def chrono_fun(self, fun: Callable, *args,
|
|
101
|
+
def chrono_fun(self, fun: Callable, *args, **kwargs) -> Array:
|
|
81
102
|
start = time.perf_counter()
|
|
82
|
-
out = fun(*args)
|
|
103
|
+
out = fun(*args, **kwargs)
|
|
83
104
|
if self.jax_fn:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
105
|
+
|
|
106
|
+
def _block(x):
|
|
107
|
+
if isinstance(x, Array):
|
|
108
|
+
x.block_until_ready()
|
|
109
|
+
|
|
110
|
+
jax.tree_map(_block, out)
|
|
88
111
|
end = time.perf_counter()
|
|
89
112
|
self.times.append((end - start) * 1e3)
|
|
90
113
|
return out
|
|
91
114
|
|
|
92
|
-
def _get_mean_times(self) ->
|
|
115
|
+
def _get_mean_times(self) -> Array:
|
|
93
116
|
if jax.device_count() == 1 or jax.process_count() == 1:
|
|
94
117
|
return np.array(self.times)
|
|
95
118
|
|
|
@@ -133,8 +156,12 @@ class Timer:
|
|
|
133
156
|
backend: str = "NCCL",
|
|
134
157
|
nodes: int = 1,
|
|
135
158
|
md_filename: str | None = None,
|
|
159
|
+
npz_data: Optional[dict] = None,
|
|
136
160
|
extra_info: dict = {},
|
|
137
161
|
):
|
|
162
|
+
if self.jit_time == 0.0 and len(self.times) == 0:
|
|
163
|
+
print(f"No profiling data to report for {function}")
|
|
164
|
+
return
|
|
138
165
|
|
|
139
166
|
if md_filename is None:
|
|
140
167
|
dirname, filename = (
|
|
@@ -147,6 +174,18 @@ class Timer:
|
|
|
147
174
|
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
|
|
148
175
|
)
|
|
149
176
|
|
|
177
|
+
if npz_data is not None:
|
|
178
|
+
dirname, filename = (
|
|
179
|
+
os.path.dirname(csv_filename),
|
|
180
|
+
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
181
|
+
)
|
|
182
|
+
report_folder = filename if dirname == "" else f"{dirname}/{filename}"
|
|
183
|
+
os.makedirs(report_folder, exist_ok=True)
|
|
184
|
+
npz_filename = (
|
|
185
|
+
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
|
|
186
|
+
)
|
|
187
|
+
np.savez(npz_filename, **npz_data)
|
|
188
|
+
|
|
150
189
|
y = x if y is None else y
|
|
151
190
|
z = x if z is None else z
|
|
152
191
|
|
|
@@ -158,18 +197,10 @@ class Timer:
|
|
|
158
197
|
mean_time = np.mean(times_array)
|
|
159
198
|
std_time = np.std(times_array)
|
|
160
199
|
last_time = times_array[-1]
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
argument_size = self.profiling_data["argument_size"]
|
|
166
|
-
output_size = self.profiling_data["output_size"]
|
|
167
|
-
temp_size = self.profiling_data["temp_size"]
|
|
168
|
-
else:
|
|
169
|
-
generated_code = "N/A"
|
|
170
|
-
argument_size = "N/A"
|
|
171
|
-
output_size = "N/A"
|
|
172
|
-
temp_size = "N/A"
|
|
200
|
+
generated_code = self.profiling_data["generated_code"]
|
|
201
|
+
argument_size = self.profiling_data["argument_size"]
|
|
202
|
+
output_size = self.profiling_data["output_size"]
|
|
203
|
+
temp_size = self.profiling_data["temp_size"]
|
|
173
204
|
|
|
174
205
|
csv_line = (
|
|
175
206
|
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
@@ -233,7 +264,7 @@ class Timer:
|
|
|
233
264
|
headers=["Iteration", "Time"],
|
|
234
265
|
tablefmt="github",
|
|
235
266
|
))
|
|
236
|
-
if self.jax_fn:
|
|
267
|
+
if self.jax_fn and self.compile_info:
|
|
237
268
|
f.write("\n---\n")
|
|
238
269
|
f.write(f"## Compiled Code\n")
|
|
239
270
|
f.write(f"```hlo\n")
|
|
@@ -248,7 +248,7 @@ def clean_up_csv(
|
|
|
248
248
|
data_sizes: Optional[List[int]] = None,
|
|
249
249
|
pdims: Optional[List[str]] = None,
|
|
250
250
|
pdims_strategy: List[str] = ['plot_fastest'],
|
|
251
|
-
backends: List[str] =
|
|
251
|
+
backends: Optional[List[str]] = None,
|
|
252
252
|
memory_units: str = 'KB',
|
|
253
253
|
) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
|
|
254
254
|
"""
|
|
@@ -292,7 +292,7 @@ def clean_up_csv(
|
|
|
292
292
|
|
|
293
293
|
df = pd.read_csv(csv_file,
|
|
294
294
|
header=None,
|
|
295
|
-
skiprows=
|
|
295
|
+
skiprows=0,
|
|
296
296
|
names=[
|
|
297
297
|
"function", "precision", "x", "y", "z", "px",
|
|
298
298
|
"py", "backend", "nodes", "jit_time", "min_time",
|
|
@@ -331,7 +331,8 @@ def clean_up_csv(
|
|
|
331
331
|
if function_names:
|
|
332
332
|
df = df[df['function'].isin(function_names)]
|
|
333
333
|
# Filter backends
|
|
334
|
-
|
|
334
|
+
if backends:
|
|
335
|
+
df = df[df['backend'].isin(backends)]
|
|
335
336
|
|
|
336
337
|
# Filter data sizes
|
|
337
338
|
if data_sizes:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: jax_hpc_profiler
|
|
3
|
-
Version: 0.2.9
|
|
3
|
+
Version: 0.2.11
|
|
4
4
|
Summary: HPC Plotter and profiler for benchmarking data made for JAX
|
|
5
5
|
Author: Wassim Kabalan
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -698,8 +698,7 @@ Requires-Dist: pandas
|
|
|
698
698
|
Requires-Dist: matplotlib
|
|
699
699
|
Requires-Dist: seaborn
|
|
700
700
|
Requires-Dist: tabulate
|
|
701
|
-
|
|
702
|
-
Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
|
|
701
|
+
Dynamic: license-file
|
|
703
702
|
|
|
704
703
|
# JAX HPC Profiler
|
|
705
704
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/SOURCES.txt
RENAMED
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/entry_points.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/requires.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/top_level.txt
RENAMED
|
File without changes
|