jax-hpc-profiler 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/create_argparse.py +13 -0
- jax_hpc_profiler/main.py +4 -0
- jax_hpc_profiler/plotting.py +9 -5
- jax_hpc_profiler/timer.py +46 -31
- jax_hpc_profiler/utils.py +1 -1
- {jax_hpc_profiler-0.2.10.dist-info → jax_hpc_profiler-0.2.12.dist-info}/METADATA +3 -4
- jax_hpc_profiler-0.2.12.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.10.dist-info → jax_hpc_profiler-0.2.12.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.10.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.10.dist-info → jax_hpc_profiler-0.2.12.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.10.dist-info → jax_hpc_profiler-0.2.12.dist-info/licenses}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.10.dist-info → jax_hpc_profiler-0.2.12.dist-info}/top_level.txt +0 -0
|
@@ -168,6 +168,19 @@ def create_argparser():
|
|
|
168
168
|
default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
169
169
|
)
|
|
170
170
|
|
|
171
|
+
plot_parser.add_argument(
|
|
172
|
+
"-xl",
|
|
173
|
+
"--xlabel",
|
|
174
|
+
type=str,
|
|
175
|
+
help="X-axis label for the plot",
|
|
176
|
+
)
|
|
177
|
+
plot_parser.add_argument(
|
|
178
|
+
"-tl",
|
|
179
|
+
"--title",
|
|
180
|
+
type=str,
|
|
181
|
+
help="Title for the plot",
|
|
182
|
+
)
|
|
183
|
+
|
|
171
184
|
subparsers.add_parser("label_help", help="Label customization help")
|
|
172
185
|
|
|
173
186
|
args = parser.parse_args()
|
jax_hpc_profiler/main.py
CHANGED
|
@@ -37,6 +37,8 @@ def main():
|
|
|
37
37
|
args.plot_columns,
|
|
38
38
|
args.memory_units,
|
|
39
39
|
args.label_text,
|
|
40
|
+
args.title,
|
|
41
|
+
args.label_text,
|
|
40
42
|
args.figure_size,
|
|
41
43
|
args.dark_bg,
|
|
42
44
|
args.output,
|
|
@@ -55,6 +57,8 @@ def main():
|
|
|
55
57
|
args.plot_columns,
|
|
56
58
|
args.memory_units,
|
|
57
59
|
args.label_text,
|
|
60
|
+
args.title,
|
|
61
|
+
args.label_text,
|
|
58
62
|
args.figure_size,
|
|
59
63
|
args.dark_bg,
|
|
60
64
|
args.output,
|
jax_hpc_profiler/plotting.py
CHANGED
|
@@ -17,8 +17,8 @@ def configure_axes(
|
|
|
17
17
|
ax: Axes,
|
|
18
18
|
x_values: List[int],
|
|
19
19
|
y_values: List[float],
|
|
20
|
-
xlabel: str,
|
|
21
20
|
title: str,
|
|
21
|
+
xlabel: str,
|
|
22
22
|
plotting_memory: bool = False,
|
|
23
23
|
memory_units: str = "bytes",
|
|
24
24
|
):
|
|
@@ -213,6 +213,8 @@ def plot_strong_scaling(
|
|
|
213
213
|
plot_columns: List[str] = ["mean_time"],
|
|
214
214
|
memory_units: str = "bytes",
|
|
215
215
|
label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
216
|
+
xlabel: str = "Number of GPUs",
|
|
217
|
+
title: str = "Data sizes",
|
|
216
218
|
figure_size: tuple = (6, 4),
|
|
217
219
|
dark_bg: bool = False,
|
|
218
220
|
output: Optional[str] = None,
|
|
@@ -241,8 +243,8 @@ def plot_strong_scaling(
|
|
|
241
243
|
available_data_sizes,
|
|
242
244
|
"gpus",
|
|
243
245
|
"x",
|
|
244
|
-
|
|
245
|
-
|
|
246
|
+
xlabel,
|
|
247
|
+
title,
|
|
246
248
|
figure_size,
|
|
247
249
|
output,
|
|
248
250
|
dark_bg,
|
|
@@ -270,6 +272,8 @@ def plot_weak_scaling(
|
|
|
270
272
|
plot_columns: List[str] = ["mean_time"],
|
|
271
273
|
memory_units: str = "bytes",
|
|
272
274
|
label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
275
|
+
xlabel: str = "Data sizes",
|
|
276
|
+
title: str = "Number of GPUs",
|
|
273
277
|
figure_size: tuple = (6, 4),
|
|
274
278
|
dark_bg: bool = False,
|
|
275
279
|
output: Optional[str] = None,
|
|
@@ -297,8 +301,8 @@ def plot_weak_scaling(
|
|
|
297
301
|
available_gpu_counts,
|
|
298
302
|
"x",
|
|
299
303
|
"gpus",
|
|
300
|
-
|
|
301
|
-
|
|
304
|
+
xlabel,
|
|
305
|
+
title,
|
|
302
306
|
figure_size,
|
|
303
307
|
output,
|
|
304
308
|
dark_bg,
|
jax_hpc_profiler/timer.py
CHANGED
|
@@ -11,23 +11,31 @@ from jax.experimental import mesh_utils
|
|
|
11
11
|
from jax.experimental.shard_map import shard_map
|
|
12
12
|
from jax.sharding import Mesh, NamedSharding
|
|
13
13
|
from jax.sharding import PartitionSpec as P
|
|
14
|
+
from jaxtyping import Array
|
|
14
15
|
from tabulate import tabulate
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class Timer:
|
|
18
19
|
|
|
19
|
-
def __init__(self,
|
|
20
|
+
def __init__(self,
|
|
21
|
+
save_jaxpr=False,
|
|
22
|
+
compile_info=True,
|
|
23
|
+
jax_fn=True,
|
|
24
|
+
devices=None,
|
|
25
|
+
static_argnums=()):
|
|
20
26
|
self.jit_time = 0.0
|
|
21
27
|
self.times = []
|
|
22
28
|
self.profiling_data = {}
|
|
23
29
|
self.compiled_code = {}
|
|
24
30
|
self.save_jaxpr = save_jaxpr
|
|
31
|
+
self.compile_info = compile_info
|
|
25
32
|
self.jax_fn = jax_fn
|
|
26
33
|
self.devices = devices
|
|
34
|
+
self.static_argnums = static_argnums
|
|
27
35
|
|
|
28
36
|
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
29
37
|
|
|
30
|
-
if not self.jax_fn:
|
|
38
|
+
if not (self.jax_fn and self.compile_info):
|
|
31
39
|
return memory_analysis
|
|
32
40
|
|
|
33
41
|
sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
|
|
@@ -47,23 +55,36 @@ class Timer:
|
|
|
47
55
|
memory_analysis.temp_size_in_bytes,
|
|
48
56
|
)
|
|
49
57
|
|
|
50
|
-
def chrono_jit(self, fun: Callable, *args,
|
|
58
|
+
def chrono_jit(self, fun: Callable, *args, **kwargs) -> Array:
|
|
51
59
|
start = time.perf_counter()
|
|
52
|
-
out = fun(*args)
|
|
60
|
+
out = fun(*args, **kwargs)
|
|
53
61
|
if self.jax_fn:
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
62
|
+
|
|
63
|
+
def _block(x):
|
|
64
|
+
if isinstance(x, Array):
|
|
65
|
+
x.block_until_ready()
|
|
66
|
+
|
|
67
|
+
jax.tree_map(_block, out)
|
|
58
68
|
end = time.perf_counter()
|
|
59
69
|
self.jit_time = (end - start) * 1e3
|
|
60
70
|
|
|
71
|
+
self.compiled_code["JAXPR"] = "N/A"
|
|
72
|
+
self.compiled_code["LOWERED"] = "N/A"
|
|
73
|
+
self.compiled_code["COMPILED"] = "N/A"
|
|
74
|
+
self.profiling_data["generated_code"] = "N/A"
|
|
75
|
+
self.profiling_data["argument_size"] = "N/A"
|
|
76
|
+
self.profiling_data["output_size"] = "N/A"
|
|
77
|
+
self.profiling_data["temp_size"] = "N/A"
|
|
78
|
+
|
|
61
79
|
if self.save_jaxpr:
|
|
62
|
-
jaxpr = make_jaxpr(fun
|
|
80
|
+
jaxpr = make_jaxpr(fun,
|
|
81
|
+
static_argnums=self.static_argnums)(*args,
|
|
82
|
+
**kwargs)
|
|
63
83
|
self.compiled_code["JAXPR"] = jaxpr.pretty_print()
|
|
64
84
|
|
|
65
|
-
if self.jax_fn:
|
|
66
|
-
lowered = jax.jit(fun).lower(
|
|
85
|
+
if self.jax_fn and self.compile_info:
|
|
86
|
+
lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
|
|
87
|
+
*args, **kwargs)
|
|
67
88
|
compiled = lowered.compile()
|
|
68
89
|
memory_analysis = self._read_memory_analysis(
|
|
69
90
|
compiled.memory_analysis())
|
|
@@ -77,19 +98,21 @@ class Timer:
|
|
|
77
98
|
|
|
78
99
|
return out
|
|
79
100
|
|
|
80
|
-
def chrono_fun(self, fun: Callable, *args,
|
|
101
|
+
def chrono_fun(self, fun: Callable, *args, **kwargs) -> Array:
|
|
81
102
|
start = time.perf_counter()
|
|
82
|
-
out = fun(*args)
|
|
103
|
+
out = fun(*args, **kwargs)
|
|
83
104
|
if self.jax_fn:
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
105
|
+
|
|
106
|
+
def _block(x):
|
|
107
|
+
if isinstance(x, Array):
|
|
108
|
+
x.block_until_ready()
|
|
109
|
+
|
|
110
|
+
jax.tree_map(_block, out)
|
|
88
111
|
end = time.perf_counter()
|
|
89
112
|
self.times.append((end - start) * 1e3)
|
|
90
113
|
return out
|
|
91
114
|
|
|
92
|
-
def _get_mean_times(self) ->
|
|
115
|
+
def _get_mean_times(self) -> Array:
|
|
93
116
|
if jax.device_count() == 1 or jax.process_count() == 1:
|
|
94
117
|
return np.array(self.times)
|
|
95
118
|
|
|
@@ -174,18 +197,10 @@ class Timer:
|
|
|
174
197
|
mean_time = np.mean(times_array)
|
|
175
198
|
std_time = np.std(times_array)
|
|
176
199
|
last_time = times_array[-1]
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
argument_size = self.profiling_data["argument_size"]
|
|
182
|
-
output_size = self.profiling_data["output_size"]
|
|
183
|
-
temp_size = self.profiling_data["temp_size"]
|
|
184
|
-
else:
|
|
185
|
-
generated_code = "N/A"
|
|
186
|
-
argument_size = "N/A"
|
|
187
|
-
output_size = "N/A"
|
|
188
|
-
temp_size = "N/A"
|
|
200
|
+
generated_code = self.profiling_data["generated_code"]
|
|
201
|
+
argument_size = self.profiling_data["argument_size"]
|
|
202
|
+
output_size = self.profiling_data["output_size"]
|
|
203
|
+
temp_size = self.profiling_data["temp_size"]
|
|
189
204
|
|
|
190
205
|
csv_line = (
|
|
191
206
|
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
@@ -249,7 +264,7 @@ class Timer:
|
|
|
249
264
|
headers=["Iteration", "Time"],
|
|
250
265
|
tablefmt="github",
|
|
251
266
|
))
|
|
252
|
-
if self.jax_fn:
|
|
267
|
+
if self.jax_fn and self.compile_info:
|
|
253
268
|
f.write("\n---\n")
|
|
254
269
|
f.write(f"## Compiled Code\n")
|
|
255
270
|
f.write(f"```hlo\n")
|
jax_hpc_profiler/utils.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: jax_hpc_profiler
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.12
|
|
4
4
|
Summary: HPC Plotter and profiler for benchmarking data made for JAX
|
|
5
5
|
Author: Wassim Kabalan
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -698,8 +698,7 @@ Requires-Dist: pandas
|
|
|
698
698
|
Requires-Dist: matplotlib
|
|
699
699
|
Requires-Dist: seaborn
|
|
700
700
|
Requires-Dist: tabulate
|
|
701
|
-
|
|
702
|
-
Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
|
|
701
|
+
Dynamic: license-file
|
|
703
702
|
|
|
704
703
|
# JAX HPC Profiler
|
|
705
704
|
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
+
jax_hpc_profiler/create_argparse.py,sha256=6DVpYuj908L01vk09-l7BLU6dW4OGgTepFR123fsawM,6314
|
|
3
|
+
jax_hpc_profiler/main.py,sha256=dwOik2rJw5YV6ocQ-EE32iFOPlq2_3CHHAAuJJFt65Q,2286
|
|
4
|
+
jax_hpc_profiler/plotting.py,sha256=R0mjUhV_Q-qi02mlxWiR241sxr58USBSykwFdjBa-oM,9484
|
|
5
|
+
jax_hpc_profiler/timer.py,sha256=4zc5HlJwepMK633BDz0iLTLWcLsvPdd6M1SL0-qs4js,10554
|
|
6
|
+
jax_hpc_profiler/utils.py,sha256=7i8qPfKogp8nGaGdyJ2-fbQomhIZqn73PQ14qldpFTc,14657
|
|
7
|
+
jax_hpc_profiler-0.2.12.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
+
jax_hpc_profiler-0.2.12.dist-info/METADATA,sha256=bIriXXGdfJ7yTIdsuIdZgig3WpshBibBYKtGDjukr8A,49186
|
|
9
|
+
jax_hpc_profiler-0.2.12.dist-info/WHEEL,sha256=tTnHoFhvKQHCh4jz3yCn0WPTYIy7wXx3CJtJ7SJGV7c,91
|
|
10
|
+
jax_hpc_profiler-0.2.12.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
+
jax_hpc_profiler-0.2.12.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
+
jax_hpc_profiler-0.2.12.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
-
jax_hpc_profiler/create_argparse.py,sha256=CSdl76LvaTVVn43dkwpVyiIkyl4lHlDCiI5jvUrIoj0,6059
|
|
3
|
-
jax_hpc_profiler/main.py,sha256=2zPVTGRgFkYV75EJA1eoOqf92gCRXAtg-28cFgRy3Bw,2164
|
|
4
|
-
jax_hpc_profiler/plotting.py,sha256=PdRdEIjsiiDbpr7iwvDUnC4mXz9QWE3JIvOAiWiSS3w,9382
|
|
5
|
-
jax_hpc_profiler/timer.py,sha256=DlbB4O4qJZImcjFq9T6zuywVbZT0x4dBxiJTAR0LNtY,9913
|
|
6
|
-
jax_hpc_profiler/utils.py,sha256=hyWldQIjlNs2laPxgc19szzPnJSHYnqiT9p3knPor8Y,14657
|
|
7
|
-
jax_hpc_profiler-0.2.10.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
-
jax_hpc_profiler-0.2.10.dist-info/METADATA,sha256=Ca9EYifDf9zDD8nRMmLvxvBu-JUwsI3wIgqClT35egI,49271
|
|
9
|
-
jax_hpc_profiler-0.2.10.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
10
|
-
jax_hpc_profiler-0.2.10.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
-
jax_hpc_profiler-0.2.10.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
-
jax_hpc_profiler-0.2.10.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|