jax-hpc-profiler 0.2.9__tar.gz → 0.2.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/PKG-INFO +3 -4
  2. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/README.md +0 -2
  3. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/pyproject.toml +1 -1
  4. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/create_argparse.py +13 -0
  5. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/main.py +4 -0
  6. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/plotting.py +23 -14
  7. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/timer.py +63 -32
  8. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/utils.py +4 -3
  9. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/PKG-INFO +3 -4
  10. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/LICENSE +0 -0
  11. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/setup.cfg +0 -0
  12. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler/__init__.py +0 -0
  13. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  14. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  15. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  16. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  17. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.11}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.9
3
+ Version: 0.2.11
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -698,8 +698,7 @@ Requires-Dist: pandas
698
698
  Requires-Dist: matplotlib
699
699
  Requires-Dist: seaborn
700
700
  Requires-Dist: tabulate
701
-
702
- Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
701
+ Dynamic: license-file
703
702
 
704
703
  # JAX HPC Profiler
705
704
 
@@ -1,5 +1,3 @@
1
- Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
2
-
3
1
  # JAX HPC Profiler
4
2
 
5
3
  JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.9"
7
+ version = "0.2.11"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -168,6 +168,19 @@ def create_argparser():
168
168
  default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
169
169
  )
170
170
 
171
+ plot_parser.add_argument(
172
+ "-xl",
173
+ "--xlabel",
174
+ type=str,
175
+ help="X-axis label for the plot",
176
+ )
177
+ plot_parser.add_argument(
178
+ "-tl",
179
+ "--title",
180
+ type=str,
181
+ help="Title for the plot",
182
+ )
183
+
171
184
  subparsers.add_parser("label_help", help="Label customization help")
172
185
 
173
186
  args = parser.parse_args()
@@ -37,6 +37,8 @@ def main():
37
37
  args.plot_columns,
38
38
  args.memory_units,
39
39
  args.label_text,
40
+ args.title,
41
+ args.label_text,
40
42
  args.figure_size,
41
43
  args.dark_bg,
42
44
  args.output,
@@ -55,6 +57,8 @@ def main():
55
57
  args.plot_columns,
56
58
  args.memory_units,
57
59
  args.label_text,
60
+ args.title,
61
+ args.label_text,
58
62
  args.figure_size,
59
63
  args.dark_bg,
60
64
  args.output,
@@ -17,8 +17,8 @@ def configure_axes(
17
17
  ax: Axes,
18
18
  x_values: List[int],
19
19
  y_values: List[float],
20
- xlabel: str,
21
20
  title: str,
21
+ xlabel: str,
22
22
  plotting_memory: bool = False,
23
23
  memory_units: str = "bytes",
24
24
  ):
@@ -77,9 +77,9 @@ def plot_scaling(
77
77
  output: Optional[str] = None,
78
78
  dark_bg: bool = False,
79
79
  print_decompositions: bool = False,
80
- backends: List[str] = ["NCCL"],
81
- precisions: List[str] = ["float32"],
82
- functions: List[str] | None = None,
80
+ backends: Optional[List[str]] = None,
81
+ precisions: Optional[List[str]] = None,
82
+ functions: Optional[List[str]] = None,
83
83
  plot_columns: List[str] = ["mean_time"],
84
84
  memory_units: str = "bytes",
85
85
  label_text: str = "plot",
@@ -141,6 +141,11 @@ def plot_scaling(
141
141
  by=[size_column])
142
142
  functions = (pd.unique(filtered_method_df["function"])
143
143
  if functions is None else functions)
144
+ precisions = (pd.unique(filtered_method_df["precision"])
145
+ if precisions is None else precisions)
146
+ backends = (pd.unique(filtered_method_df["backend"])
147
+ if backends is None else backends)
148
+
144
149
  combinations = product(backends, precisions, functions,
145
150
  plot_columns)
146
151
 
@@ -199,15 +204,17 @@ def plot_strong_scaling(
199
204
  csv_files: List[str],
200
205
  fixed_gpu_size: Optional[List[int]] = None,
201
206
  fixed_data_size: Optional[List[int]] = None,
202
- functions: List[str] | None = None,
203
- precisions: List[str] = ["float32"],
207
+ functions: Optional[List[str]] = None,
208
+ precisions: Optional[List[str]] = None,
204
209
  pdims: Optional[List[str]] = None,
205
210
  pdims_strategy: List[str] = ["plot_fastest"],
206
211
  print_decompositions: bool = False,
207
- backends: List[str] = ["NCCL"],
212
+ backends: Optional[List[str]] = None,
208
213
  plot_columns: List[str] = ["mean_time"],
209
214
  memory_units: str = "bytes",
210
215
  label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
216
+ xlabel: str = "Number of GPUs",
217
+ title: str = "Data sizes",
211
218
  figure_size: tuple = (6, 4),
212
219
  dark_bg: bool = False,
213
220
  output: Optional[str] = None,
@@ -236,8 +243,8 @@ def plot_strong_scaling(
236
243
  available_data_sizes,
237
244
  "gpus",
238
245
  "x",
239
- "Number of GPUs",
240
- "Data size",
246
+ xlabel,
247
+ title,
241
248
  figure_size,
242
249
  output,
243
250
  dark_bg,
@@ -256,15 +263,17 @@ def plot_weak_scaling(
256
263
  csv_files: List[str],
257
264
  fixed_gpu_size: Optional[List[int]] = None,
258
265
  fixed_data_size: Optional[List[int]] = None,
259
- functions: List[str] | None = None,
260
- precisions: List[str] = ["float32"],
266
+ functions: Optional[List[str]] = None,
267
+ precisions: Optional[List[str]] = None,
261
268
  pdims: Optional[List[str]] = None,
262
269
  pdims_strategy: List[str] = ["plot_fastest"],
263
270
  print_decompositions: bool = False,
264
- backends: List[str] = ["NCCL"],
271
+ backends: Optional[List[str]] = None,
265
272
  plot_columns: List[str] = ["mean_time"],
266
273
  memory_units: str = "bytes",
267
274
  label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
275
+ xlabel: str = "Data sizes",
276
+ title: str = "Number of GPUs",
268
277
  figure_size: tuple = (6, 4),
269
278
  dark_bg: bool = False,
270
279
  output: Optional[str] = None,
@@ -292,8 +301,8 @@ def plot_weak_scaling(
292
301
  available_gpu_counts,
293
302
  "x",
294
303
  "gpus",
295
- "Data size",
296
- "Number of GPUs",
304
+ xlabel,
305
+ title,
297
306
  figure_size,
298
307
  output,
299
308
  dark_bg,
@@ -1,7 +1,7 @@
1
1
  import os
2
2
  import time
3
3
  from functools import partial
4
- from typing import Any, Callable, List, Tuple
4
+ from typing import Any, Callable, List, Optional, Tuple
5
5
 
6
6
  import jax
7
7
  import jax.numpy as jnp
@@ -11,23 +11,31 @@ from jax.experimental import mesh_utils
11
11
  from jax.experimental.shard_map import shard_map
12
12
  from jax.sharding import Mesh, NamedSharding
13
13
  from jax.sharding import PartitionSpec as P
14
+ from jaxtyping import Array
14
15
  from tabulate import tabulate
15
16
 
16
17
 
17
18
  class Timer:
18
19
 
19
- def __init__(self, save_jaxpr=False, jax_fn=True, devices=None):
20
+ def __init__(self,
21
+ save_jaxpr=False,
22
+ compile_info=True,
23
+ jax_fn=True,
24
+ devices=None,
25
+ static_argnums=()):
20
26
  self.jit_time = 0.0
21
27
  self.times = []
22
28
  self.profiling_data = {}
23
29
  self.compiled_code = {}
24
30
  self.save_jaxpr = save_jaxpr
31
+ self.compile_info = compile_info
25
32
  self.jax_fn = jax_fn
26
33
  self.devices = devices
34
+ self.static_argnums = static_argnums
27
35
 
28
36
  def _normalize_memory_units(self, memory_analysis) -> str:
29
37
 
30
- if not self.jax_fn:
38
+ if not (self.jax_fn and self.compile_info):
31
39
  return memory_analysis
32
40
 
33
41
  sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
@@ -47,23 +55,36 @@ class Timer:
47
55
  memory_analysis.temp_size_in_bytes,
48
56
  )
49
57
 
50
- def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
58
+ def chrono_jit(self, fun: Callable, *args, **kwargs) -> Array:
51
59
  start = time.perf_counter()
52
- out = fun(*args)
60
+ out = fun(*args, **kwargs)
53
61
  if self.jax_fn:
54
- if ndarray_arg is None:
55
- out.block_until_ready()
56
- else:
57
- out[ndarray_arg].block_until_ready()
62
+
63
+ def _block(x):
64
+ if isinstance(x, Array):
65
+ x.block_until_ready()
66
+
67
+ jax.tree_map(_block, out)
58
68
  end = time.perf_counter()
59
69
  self.jit_time = (end - start) * 1e3
60
70
 
71
+ self.compiled_code["JAXPR"] = "N/A"
72
+ self.compiled_code["LOWERED"] = "N/A"
73
+ self.compiled_code["COMPILED"] = "N/A"
74
+ self.profiling_data["generated_code"] = "N/A"
75
+ self.profiling_data["argument_size"] = "N/A"
76
+ self.profiling_data["output_size"] = "N/A"
77
+ self.profiling_data["temp_size"] = "N/A"
78
+
61
79
  if self.save_jaxpr:
62
- jaxpr = make_jaxpr(fun)(*args)
80
+ jaxpr = make_jaxpr(fun,
81
+ static_argnums=self.static_argnums)(*args,
82
+ **kwargs)
63
83
  self.compiled_code["JAXPR"] = jaxpr.pretty_print()
64
84
 
65
- if self.jax_fn:
66
- lowered = jax.jit(fun).lower(*args)
85
+ if self.jax_fn and self.compile_info:
86
+ lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
87
+ *args, **kwargs)
67
88
  compiled = lowered.compile()
68
89
  memory_analysis = self._read_memory_analysis(
69
90
  compiled.memory_analysis())
@@ -77,19 +98,21 @@ class Timer:
77
98
 
78
99
  return out
79
100
 
80
- def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
101
+ def chrono_fun(self, fun: Callable, *args, **kwargs) -> Array:
81
102
  start = time.perf_counter()
82
- out = fun(*args)
103
+ out = fun(*args, **kwargs)
83
104
  if self.jax_fn:
84
- if ndarray_arg is None:
85
- out.block_until_ready()
86
- else:
87
- out[ndarray_arg].block_until_ready()
105
+
106
+ def _block(x):
107
+ if isinstance(x, Array):
108
+ x.block_until_ready()
109
+
110
+ jax.tree_map(_block, out)
88
111
  end = time.perf_counter()
89
112
  self.times.append((end - start) * 1e3)
90
113
  return out
91
114
 
92
- def _get_mean_times(self) -> np.ndarray:
115
+ def _get_mean_times(self) -> Array:
93
116
  if jax.device_count() == 1 or jax.process_count() == 1:
94
117
  return np.array(self.times)
95
118
 
@@ -133,8 +156,12 @@ class Timer:
133
156
  backend: str = "NCCL",
134
157
  nodes: int = 1,
135
158
  md_filename: str | None = None,
159
+ npz_data: Optional[dict] = None,
136
160
  extra_info: dict = {},
137
161
  ):
162
+ if self.jit_time == 0.0 and len(self.times) == 0:
163
+ print(f"No profiling data to report for {function}")
164
+ return
138
165
 
139
166
  if md_filename is None:
140
167
  dirname, filename = (
@@ -147,6 +174,18 @@ class Timer:
147
174
  f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
148
175
  )
149
176
 
177
+ if npz_data is not None:
178
+ dirname, filename = (
179
+ os.path.dirname(csv_filename),
180
+ os.path.splitext(os.path.basename(csv_filename))[0],
181
+ )
182
+ report_folder = filename if dirname == "" else f"{dirname}/{filename}"
183
+ os.makedirs(report_folder, exist_ok=True)
184
+ npz_filename = (
185
+ f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
186
+ )
187
+ np.savez(npz_filename, **npz_data)
188
+
150
189
  y = x if y is None else y
151
190
  z = x if z is None else z
152
191
 
@@ -158,18 +197,10 @@ class Timer:
158
197
  mean_time = np.mean(times_array)
159
198
  std_time = np.std(times_array)
160
199
  last_time = times_array[-1]
161
-
162
- if self.jax_fn:
163
-
164
- generated_code = self.profiling_data["generated_code"]
165
- argument_size = self.profiling_data["argument_size"]
166
- output_size = self.profiling_data["output_size"]
167
- temp_size = self.profiling_data["temp_size"]
168
- else:
169
- generated_code = "N/A"
170
- argument_size = "N/A"
171
- output_size = "N/A"
172
- temp_size = "N/A"
200
+ generated_code = self.profiling_data["generated_code"]
201
+ argument_size = self.profiling_data["argument_size"]
202
+ output_size = self.profiling_data["output_size"]
203
+ temp_size = self.profiling_data["temp_size"]
173
204
 
174
205
  csv_line = (
175
206
  f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
@@ -233,7 +264,7 @@ class Timer:
233
264
  headers=["Iteration", "Time"],
234
265
  tablefmt="github",
235
266
  ))
236
- if self.jax_fn:
267
+ if self.jax_fn and self.compile_info:
237
268
  f.write("\n---\n")
238
269
  f.write(f"## Compiled Code\n")
239
270
  f.write(f"```hlo\n")
@@ -248,7 +248,7 @@ def clean_up_csv(
248
248
  data_sizes: Optional[List[int]] = None,
249
249
  pdims: Optional[List[str]] = None,
250
250
  pdims_strategy: List[str] = ['plot_fastest'],
251
- backends: List[str] = ['MPI', 'NCCL', 'MPI4JAX'],
251
+ backends: Optional[List[str]] = None,
252
252
  memory_units: str = 'KB',
253
253
  ) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
254
254
  """
@@ -292,7 +292,7 @@ def clean_up_csv(
292
292
 
293
293
  df = pd.read_csv(csv_file,
294
294
  header=None,
295
- skiprows=1,
295
+ skiprows=0,
296
296
  names=[
297
297
  "function", "precision", "x", "y", "z", "px",
298
298
  "py", "backend", "nodes", "jit_time", "min_time",
@@ -331,7 +331,8 @@ def clean_up_csv(
331
331
  if function_names:
332
332
  df = df[df['function'].isin(function_names)]
333
333
  # Filter backends
334
- df = df[df['backend'].isin(backends)]
334
+ if backends:
335
+ df = df[df['backend'].isin(backends)]
335
336
 
336
337
  # Filter data sizes
337
338
  if data_sizes:
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.9
3
+ Version: 0.2.11
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -698,8 +698,7 @@ Requires-Dist: pandas
698
698
  Requires-Dist: matplotlib
699
699
  Requires-Dist: seaborn
700
700
  Requires-Dist: tabulate
701
-
702
- Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
701
+ Dynamic: license-file
703
702
 
704
703
  # JAX HPC Profiler
705
704