jax-hpc-profiler 0.2.9__tar.gz → 0.2.10__tar.gz

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (17):
  1. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/PKG-INFO +2 -2
  2. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/pyproject.toml +1 -1
  3. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/plotting.py +14 -9
  4. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/timer.py +17 -1
  5. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/utils.py +3 -2
  6. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/PKG-INFO +2 -2
  7. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/LICENSE +0 -0
  8. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/README.md +0 -0
  9. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/setup.cfg +0 -0
  10. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/__init__.py +0 -0
  11. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/create_argparse.py +0 -0
  12. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/main.py +0 -0
  13. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  14. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  15. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  16. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  17. {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.9
3
+ Version: 0.2.10
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.9"
7
+ version = "0.2.10"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -77,9 +77,9 @@ def plot_scaling(
77
77
  output: Optional[str] = None,
78
78
  dark_bg: bool = False,
79
79
  print_decompositions: bool = False,
80
- backends: List[str] = ["NCCL"],
81
- precisions: List[str] = ["float32"],
82
- functions: List[str] | None = None,
80
+ backends: Optional[List[str]] = None,
81
+ precisions: Optional[List[str]] = None,
82
+ functions: Optional[List[str]] = None,
83
83
  plot_columns: List[str] = ["mean_time"],
84
84
  memory_units: str = "bytes",
85
85
  label_text: str = "plot",
@@ -141,6 +141,11 @@ def plot_scaling(
141
141
  by=[size_column])
142
142
  functions = (pd.unique(filtered_method_df["function"])
143
143
  if functions is None else functions)
144
+ precisions = (pd.unique(filtered_method_df["precision"])
145
+ if precisions is None else precisions)
146
+ backends = (pd.unique(filtered_method_df["backend"])
147
+ if backends is None else backends)
148
+
144
149
  combinations = product(backends, precisions, functions,
145
150
  plot_columns)
146
151
 
@@ -199,12 +204,12 @@ def plot_strong_scaling(
199
204
  csv_files: List[str],
200
205
  fixed_gpu_size: Optional[List[int]] = None,
201
206
  fixed_data_size: Optional[List[int]] = None,
202
- functions: List[str] | None = None,
203
- precisions: List[str] = ["float32"],
207
+ functions: Optional[List[str]] = None,
208
+ precisions: Optional[List[str]] = None,
204
209
  pdims: Optional[List[str]] = None,
205
210
  pdims_strategy: List[str] = ["plot_fastest"],
206
211
  print_decompositions: bool = False,
207
- backends: List[str] = ["NCCL"],
212
+ backends: Optional[List[str]] = None,
208
213
  plot_columns: List[str] = ["mean_time"],
209
214
  memory_units: str = "bytes",
210
215
  label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
@@ -256,12 +261,12 @@ def plot_weak_scaling(
256
261
  csv_files: List[str],
257
262
  fixed_gpu_size: Optional[List[int]] = None,
258
263
  fixed_data_size: Optional[List[int]] = None,
259
- functions: List[str] | None = None,
260
- precisions: List[str] = ["float32"],
264
+ functions: Optional[List[str]] = None,
265
+ precisions: Optional[List[str]] = None,
261
266
  pdims: Optional[List[str]] = None,
262
267
  pdims_strategy: List[str] = ["plot_fastest"],
263
268
  print_decompositions: bool = False,
264
- backends: List[str] = ["NCCL"],
269
+ backends: Optional[List[str]] = None,
265
270
  plot_columns: List[str] = ["mean_time"],
266
271
  memory_units: str = "bytes",
267
272
  label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
@@ -1,7 +1,7 @@
1
1
  import os
2
2
  import time
3
3
  from functools import partial
4
- from typing import Any, Callable, List, Tuple
4
+ from typing import Any, Callable, List, Optional, Tuple
5
5
 
6
6
  import jax
7
7
  import jax.numpy as jnp
@@ -133,8 +133,12 @@ class Timer:
133
133
  backend: str = "NCCL",
134
134
  nodes: int = 1,
135
135
  md_filename: str | None = None,
136
+ npz_data: Optional[dict] = None,
136
137
  extra_info: dict = {},
137
138
  ):
139
+ if self.jit_time == 0.0 and len(self.times) == 0:
140
+ print(f"No profiling data to report for {function}")
141
+ return
138
142
 
139
143
  if md_filename is None:
140
144
  dirname, filename = (
@@ -147,6 +151,18 @@ class Timer:
147
151
  f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
148
152
  )
149
153
 
154
+ if npz_data is not None:
155
+ dirname, filename = (
156
+ os.path.dirname(csv_filename),
157
+ os.path.splitext(os.path.basename(csv_filename))[0],
158
+ )
159
+ report_folder = filename if dirname == "" else f"{dirname}/{filename}"
160
+ os.makedirs(report_folder, exist_ok=True)
161
+ npz_filename = (
162
+ f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
163
+ )
164
+ np.savez(npz_filename, **npz_data)
165
+
150
166
  y = x if y is None else y
151
167
  z = x if z is None else z
152
168
 
@@ -248,7 +248,7 @@ def clean_up_csv(
248
248
  data_sizes: Optional[List[int]] = None,
249
249
  pdims: Optional[List[str]] = None,
250
250
  pdims_strategy: List[str] = ['plot_fastest'],
251
- backends: List[str] = ['MPI', 'NCCL', 'MPI4JAX'],
251
+ backends: Optional[List[str]] = None,
252
252
  memory_units: str = 'KB',
253
253
  ) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
254
254
  """
@@ -331,7 +331,8 @@ def clean_up_csv(
331
331
  if function_names:
332
332
  df = df[df['function'].isin(function_names)]
333
333
  # Filter backends
334
- df = df[df['backend'].isin(backends)]
334
+ if backends:
335
+ df = df[df['backend'].isin(backends)]
335
336
 
336
337
  # Filter data sizes
337
338
  if data_sizes:
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.9
3
+ Version: 0.2.10
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE