jax-hpc-profiler 0.2.10__tar.gz → 0.2.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/PKG-INFO +3 -4
  2. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/README.md +0 -2
  3. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/pyproject.toml +1 -1
  4. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler/create_argparse.py +13 -0
  5. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler/main.py +4 -0
  6. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler/plotting.py +9 -5
  7. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler/timer.py +46 -31
  8. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler/utils.py +1 -1
  9. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler.egg-info/PKG-INFO +3 -4
  10. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/LICENSE +0 -0
  11. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/setup.cfg +0 -0
  12. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler/__init__.py +0 -0
  13. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  14. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  15. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  16. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  17. {jax_hpc_profiler-0.2.10 → jax_hpc_profiler-0.2.12}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.10
3
+ Version: 0.2.12
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -698,8 +698,7 @@ Requires-Dist: pandas
698
698
  Requires-Dist: matplotlib
699
699
  Requires-Dist: seaborn
700
700
  Requires-Dist: tabulate
701
-
702
- Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
701
+ Dynamic: license-file
703
702
 
704
703
  # JAX HPC Profiler
705
704
 
@@ -1,5 +1,3 @@
1
- Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
2
-
3
1
  # JAX HPC Profiler
4
2
 
5
3
  JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.10"
7
+ version = "0.2.12"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -168,6 +168,19 @@ def create_argparser():
168
168
  default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
169
169
  )
170
170
 
171
+ plot_parser.add_argument(
172
+ "-xl",
173
+ "--xlabel",
174
+ type=str,
175
+ help="X-axis label for the plot",
176
+ )
177
+ plot_parser.add_argument(
178
+ "-tl",
179
+ "--title",
180
+ type=str,
181
+ help="Title for the plot",
182
+ )
183
+
171
184
  subparsers.add_parser("label_help", help="Label customization help")
172
185
 
173
186
  args = parser.parse_args()
@@ -37,6 +37,8 @@ def main():
37
37
  args.plot_columns,
38
38
  args.memory_units,
39
39
  args.label_text,
40
+ args.title,
41
+ args.label_text,
40
42
  args.figure_size,
41
43
  args.dark_bg,
42
44
  args.output,
@@ -55,6 +57,8 @@ def main():
55
57
  args.plot_columns,
56
58
  args.memory_units,
57
59
  args.label_text,
60
+ args.title,
61
+ args.label_text,
58
62
  args.figure_size,
59
63
  args.dark_bg,
60
64
  args.output,
@@ -17,8 +17,8 @@ def configure_axes(
17
17
  ax: Axes,
18
18
  x_values: List[int],
19
19
  y_values: List[float],
20
- xlabel: str,
21
20
  title: str,
21
+ xlabel: str,
22
22
  plotting_memory: bool = False,
23
23
  memory_units: str = "bytes",
24
24
  ):
@@ -213,6 +213,8 @@ def plot_strong_scaling(
213
213
  plot_columns: List[str] = ["mean_time"],
214
214
  memory_units: str = "bytes",
215
215
  label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
216
+ xlabel: str = "Number of GPUs",
217
+ title: str = "Data sizes",
216
218
  figure_size: tuple = (6, 4),
217
219
  dark_bg: bool = False,
218
220
  output: Optional[str] = None,
@@ -241,8 +243,8 @@ def plot_strong_scaling(
241
243
  available_data_sizes,
242
244
  "gpus",
243
245
  "x",
244
- "Number of GPUs",
245
- "Data size",
246
+ xlabel,
247
+ title,
246
248
  figure_size,
247
249
  output,
248
250
  dark_bg,
@@ -270,6 +272,8 @@ def plot_weak_scaling(
270
272
  plot_columns: List[str] = ["mean_time"],
271
273
  memory_units: str = "bytes",
272
274
  label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
275
+ xlabel: str = "Data sizes",
276
+ title: str = "Number of GPUs",
273
277
  figure_size: tuple = (6, 4),
274
278
  dark_bg: bool = False,
275
279
  output: Optional[str] = None,
@@ -297,8 +301,8 @@ def plot_weak_scaling(
297
301
  available_gpu_counts,
298
302
  "x",
299
303
  "gpus",
300
- "Data size",
301
- "Number of GPUs",
304
+ xlabel,
305
+ title,
302
306
  figure_size,
303
307
  output,
304
308
  dark_bg,
@@ -11,23 +11,31 @@ from jax.experimental import mesh_utils
11
11
  from jax.experimental.shard_map import shard_map
12
12
  from jax.sharding import Mesh, NamedSharding
13
13
  from jax.sharding import PartitionSpec as P
14
+ from jaxtyping import Array
14
15
  from tabulate import tabulate
15
16
 
16
17
 
17
18
  class Timer:
18
19
 
19
- def __init__(self, save_jaxpr=False, jax_fn=True, devices=None):
20
+ def __init__(self,
21
+ save_jaxpr=False,
22
+ compile_info=True,
23
+ jax_fn=True,
24
+ devices=None,
25
+ static_argnums=()):
20
26
  self.jit_time = 0.0
21
27
  self.times = []
22
28
  self.profiling_data = {}
23
29
  self.compiled_code = {}
24
30
  self.save_jaxpr = save_jaxpr
31
+ self.compile_info = compile_info
25
32
  self.jax_fn = jax_fn
26
33
  self.devices = devices
34
+ self.static_argnums = static_argnums
27
35
 
28
36
  def _normalize_memory_units(self, memory_analysis) -> str:
29
37
 
30
- if not self.jax_fn:
38
+ if not (self.jax_fn and self.compile_info):
31
39
  return memory_analysis
32
40
 
33
41
  sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
@@ -47,23 +55,36 @@ class Timer:
47
55
  memory_analysis.temp_size_in_bytes,
48
56
  )
49
57
 
50
- def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
58
+ def chrono_jit(self, fun: Callable, *args, **kwargs) -> Array:
51
59
  start = time.perf_counter()
52
- out = fun(*args)
60
+ out = fun(*args, **kwargs)
53
61
  if self.jax_fn:
54
- if ndarray_arg is None:
55
- out.block_until_ready()
56
- else:
57
- out[ndarray_arg].block_until_ready()
62
+
63
+ def _block(x):
64
+ if isinstance(x, Array):
65
+ x.block_until_ready()
66
+
67
+ jax.tree_map(_block, out)
58
68
  end = time.perf_counter()
59
69
  self.jit_time = (end - start) * 1e3
60
70
 
71
+ self.compiled_code["JAXPR"] = "N/A"
72
+ self.compiled_code["LOWERED"] = "N/A"
73
+ self.compiled_code["COMPILED"] = "N/A"
74
+ self.profiling_data["generated_code"] = "N/A"
75
+ self.profiling_data["argument_size"] = "N/A"
76
+ self.profiling_data["output_size"] = "N/A"
77
+ self.profiling_data["temp_size"] = "N/A"
78
+
61
79
  if self.save_jaxpr:
62
- jaxpr = make_jaxpr(fun)(*args)
80
+ jaxpr = make_jaxpr(fun,
81
+ static_argnums=self.static_argnums)(*args,
82
+ **kwargs)
63
83
  self.compiled_code["JAXPR"] = jaxpr.pretty_print()
64
84
 
65
- if self.jax_fn:
66
- lowered = jax.jit(fun).lower(*args)
85
+ if self.jax_fn and self.compile_info:
86
+ lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
87
+ *args, **kwargs)
67
88
  compiled = lowered.compile()
68
89
  memory_analysis = self._read_memory_analysis(
69
90
  compiled.memory_analysis())
@@ -77,19 +98,21 @@ class Timer:
77
98
 
78
99
  return out
79
100
 
80
- def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
101
+ def chrono_fun(self, fun: Callable, *args, **kwargs) -> Array:
81
102
  start = time.perf_counter()
82
- out = fun(*args)
103
+ out = fun(*args, **kwargs)
83
104
  if self.jax_fn:
84
- if ndarray_arg is None:
85
- out.block_until_ready()
86
- else:
87
- out[ndarray_arg].block_until_ready()
105
+
106
+ def _block(x):
107
+ if isinstance(x, Array):
108
+ x.block_until_ready()
109
+
110
+ jax.tree_map(_block, out)
88
111
  end = time.perf_counter()
89
112
  self.times.append((end - start) * 1e3)
90
113
  return out
91
114
 
92
- def _get_mean_times(self) -> np.ndarray:
115
+ def _get_mean_times(self) -> Array:
93
116
  if jax.device_count() == 1 or jax.process_count() == 1:
94
117
  return np.array(self.times)
95
118
 
@@ -174,18 +197,10 @@ class Timer:
174
197
  mean_time = np.mean(times_array)
175
198
  std_time = np.std(times_array)
176
199
  last_time = times_array[-1]
177
-
178
- if self.jax_fn:
179
-
180
- generated_code = self.profiling_data["generated_code"]
181
- argument_size = self.profiling_data["argument_size"]
182
- output_size = self.profiling_data["output_size"]
183
- temp_size = self.profiling_data["temp_size"]
184
- else:
185
- generated_code = "N/A"
186
- argument_size = "N/A"
187
- output_size = "N/A"
188
- temp_size = "N/A"
200
+ generated_code = self.profiling_data["generated_code"]
201
+ argument_size = self.profiling_data["argument_size"]
202
+ output_size = self.profiling_data["output_size"]
203
+ temp_size = self.profiling_data["temp_size"]
189
204
 
190
205
  csv_line = (
191
206
  f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
@@ -249,7 +264,7 @@ class Timer:
249
264
  headers=["Iteration", "Time"],
250
265
  tablefmt="github",
251
266
  ))
252
- if self.jax_fn:
267
+ if self.jax_fn and self.compile_info:
253
268
  f.write("\n---\n")
254
269
  f.write(f"## Compiled Code\n")
255
270
  f.write(f"```hlo\n")
@@ -292,7 +292,7 @@ def clean_up_csv(
292
292
 
293
293
  df = pd.read_csv(csv_file,
294
294
  header=None,
295
- skiprows=1,
295
+ skiprows=0,
296
296
  names=[
297
297
  "function", "precision", "x", "y", "z", "px",
298
298
  "py", "backend", "nodes", "jit_time", "min_time",
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.10
3
+ Version: 0.2.12
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -698,8 +698,7 @@ Requires-Dist: pandas
698
698
  Requires-Dist: matplotlib
699
699
  Requires-Dist: seaborn
700
700
  Requires-Dist: tabulate
701
-
702
- Here's the updated README with the additional information about the timer.report and the multi-GPU setup:
701
+ Dynamic: license-file
703
702
 
704
703
  # JAX HPC Profiler
705
704