jax-hpc-profiler 0.2.13__tar.gz → 0.3.0__tar.gz

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/PKG-INFO +36 -4
  2. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/README.md +27 -1
  3. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/pyproject.toml +22 -4
  4. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/__init__.py +2 -1
  5. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/create_argparse.py +21 -2
  6. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/main.py +28 -5
  7. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/plotting.py +191 -1
  8. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/timer.py +12 -8
  9. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/PKG-INFO +36 -4
  10. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/SOURCES.txt +3 -1
  11. jax_hpc_profiler-0.3.0/src/jax_hpc_profiler.egg-info/requires.txt +12 -0
  12. jax_hpc_profiler-0.3.0/tests/test_plotting.py +112 -0
  13. jax_hpc_profiler-0.3.0/tests/test_timer.py +103 -0
  14. jax_hpc_profiler-0.2.13/src/jax_hpc_profiler.egg-info/requires.txt +0 -5
  15. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/LICENSE +0 -0
  16. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/setup.cfg +0 -0
  17. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler/utils.py +0 -0
  18. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  19. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  20. {jax_hpc_profiler-0.2.13 → jax_hpc_profiler-0.3.0}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.13
4
- Summary: HPC Plotter and profiler for benchmarking data made for JAX
3
+ Version: 0.3.0
4
+ Summary: A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization.
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
7
7
  Version 3, 29 June 2007
@@ -679,7 +679,7 @@ License: GNU GENERAL PUBLIC LICENSE
679
679
  <https://www.gnu.org/licenses/why-not-lgpl.html>.
680
680
 
681
681
  Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
682
- Keywords: jax,hpc,profiler,plotter,benchmarking
682
+ Keywords: jax,hpc,profiling,benchmarking,visualization,scaling,performance-analysis,gpu,distributed-computing
683
683
  Classifier: Development Status :: 4 - Beta
684
684
  Classifier: Intended Audience :: Developers
685
685
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
@@ -698,10 +698,22 @@ Requires-Dist: pandas
698
698
  Requires-Dist: matplotlib
699
699
  Requires-Dist: seaborn
700
700
  Requires-Dist: tabulate
701
+ Requires-Dist: adjustText
702
+ Requires-Dist: jax>=0.4.0
703
+ Requires-Dist: jaxtyping
704
+ Provides-Extra: test
705
+ Requires-Dist: pytest; extra == "test"
706
+ Requires-Dist: pytest-cov; extra == "test"
701
707
  Dynamic: license-file
702
708
 
703
709
  # JAX HPC Profiler
704
710
 
711
+ [![Build](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
712
+ [![Code Formatting](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
713
+ [![Tests](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
714
+ [![Notebooks](https://img.shields.io/github/actions/workflow/status/ASKabalan/jax-hpc-profiler/notebooks.yml?logo=jupyter&label=notebooks)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
715
+ [![GPLv3 License](https://img.shields.io/badge/License-GPL%20v3-yellow.svg)](https://www.gnu.org/licenses/gpl-3.0)
716
+
705
717
  JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
706
718
 
707
719
  ## Table of Contents
@@ -883,9 +895,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
883
895
  - `-db, --dark_bg`: Use dark background for plotting.
884
896
  - `-pd, --print_decompositions`: Print decompositions on plot (experimental).
885
897
  - `-b, --backends`: List of backends to include. This argument can be multiple ones.
886
- - `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
898
+ - `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
899
+ - `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
900
+ - `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
901
+ - `WeakFixed`: size scaling at fixed GPU count (previous weak behavior); plots runtime (or memory) versus data size, grouped by number of GPUs.
902
+ - `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal (flat) reference line at the runtime measured for the smallest GPU count of the first plotted weak-scaling curve.
887
903
  - `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
888
904
 
905
+ ### Weak scaling CLI example
906
+
907
+ For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
908
+
909
+ ```bash
910
+ jax-hpc-profiler plot \
911
+ -f MYDATA.csv \
912
+ -pt mean_time \
913
+ -sc Weak \
914
+ -g 1 2 4 8 \
915
+ -d 32 64 128 256 \
916
+ --weak_ideal_line
917
+ ```
918
+
919
+ This will produce a single weak-scaling plot of runtime versus number of GPUs, using the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)` and overlay an ideal weak-scaling reference line.
920
+
889
921
  ## Examples
890
922
 
891
923
  The repository includes examples for both profiling and plotting.
@@ -1,5 +1,11 @@
1
1
  # JAX HPC Profiler
2
2
 
3
+ [![Build](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
4
+ [![Code Formatting](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
5
+ [![Tests](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
6
+ [![Notebooks](https://img.shields.io/github/actions/workflow/status/ASKabalan/jax-hpc-profiler/notebooks.yml?logo=jupyter&label=notebooks)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
7
+ [![GPLv3 License](https://img.shields.io/badge/License-GPL%20v3-yellow.svg)](https://www.gnu.org/licenses/gpl-3.0)
8
+
3
9
  JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
4
10
 
5
11
  ## Table of Contents
@@ -181,9 +187,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
181
187
  - `-db, --dark_bg`: Use dark background for plotting.
182
188
  - `-pd, --print_decompositions`: Print decompositions on plot (experimental).
183
189
  - `-b, --backends`: List of backends to include. This argument can be multiple ones.
184
- - `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
190
+ - `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
191
+ - `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
192
+ - `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
193
+ - `WeakFixed`: size scaling at fixed GPU count (previous weak behavior); plots runtime (or memory) versus data size, grouped by number of GPUs.
194
+ - `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal (flat) reference line at the runtime measured for the smallest GPU count of the first plotted weak-scaling curve.
185
195
  - `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
186
196
 
197
+ ### Weak scaling CLI example
198
+
199
+ For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
200
+
201
+ ```bash
202
+ jax-hpc-profiler plot \
203
+ -f MYDATA.csv \
204
+ -pt mean_time \
205
+ -sc Weak \
206
+ -g 1 2 4 8 \
207
+ -d 32 64 128 256 \
208
+ --weak_ideal_line
209
+ ```
210
+
211
+ This will produce a single weak-scaling plot of runtime versus number of GPUs, using the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)` and overlay an ideal weak-scaling reference line.
212
+
187
213
  ## Examples
188
214
 
189
215
  The repository includes examples for both profiling and plotting.
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.13"
8
- description = "HPC Plotter and profiler for benchmarking data made for JAX"
7
+ version = "0.3.0"
8
+ description = "A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization."
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
11
11
  ]
@@ -14,12 +14,15 @@ dependencies = [
14
14
  "pandas",
15
15
  "matplotlib",
16
16
  "seaborn",
17
- "tabulate"
17
+ "tabulate",
18
+ "adjustText",
19
+ "jax>=0.4.0",
20
+ "jaxtyping",
18
21
  ]
19
22
  readme = "README.md"
20
23
  license = { file = "LICENSE" }
21
24
  requires-python = ">=3.8"
22
- keywords = ["jax", "hpc", "profiler", "plotter", "benchmarking"]
25
+ keywords = ["jax", "hpc", "profiling", "benchmarking", "visualization", "scaling", "performance-analysis", "gpu", "distributed-computing"]
23
26
 
24
27
  # For a list of valid classifiers, see https://pypi.org/classifiers/
25
28
  classifiers = [
@@ -45,6 +48,12 @@ classifiers = [
45
48
 
46
49
  urls = { "Homepage" = "https://github.com/ASKabalan/jax-hpc-profiler" }
47
50
 
51
+ [project.optional-dependencies]
52
+ test = [
53
+ "pytest",
54
+ "pytest-cov",
55
+ ]
56
+
48
57
  [project.scripts]
49
58
  jhp = "jax_hpc_profiler.main:main"
50
59
 
@@ -70,7 +79,16 @@ ignore = [
70
79
  'E731',
71
80
  'E741',
72
81
  'F722', # conflicts with jaxtyping Array annotations
82
+ "E402", # module level import not at top of file
73
83
  ]
74
84
 
85
+ [tool.ruff.lint.per-file-ignores]
86
+ "*.ipynb" = ["F401"]
87
+
75
88
  [tool.ruff.format]
76
89
  quote-style = 'single'
90
+
91
+ [tool.pytest.ini_options]
92
+ addopts = "--cov=src --cov-report=term-missing"
93
+ testpaths = ["tests"]
94
+ python_files = "test_*.py"
@@ -1,5 +1,5 @@
1
1
  from .create_argparse import create_argparser
2
- from .plotting import plot_strong_scaling, plot_weak_scaling
2
+ from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
3
3
  from .timer import Timer
4
4
  from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
5
5
 
@@ -7,6 +7,7 @@ __all__ = [
7
7
  'create_argparser',
8
8
  'plot_strong_scaling',
9
9
  'plot_weak_scaling',
10
+ 'plot_weak_fixed_scaling',
10
11
  'Timer',
11
12
  'clean_up_csv',
12
13
  'concatenate_csvs',
@@ -135,9 +135,24 @@ def create_argparser():
135
135
  plot_parser.add_argument(
136
136
  '-sc',
137
137
  '--scaling',
138
- choices=['Weak', 'Strong', 'w', 's'],
138
+ choices=['Weak', 'Strong', 'WeakFixed', 'w', 's', 'wf'],
139
139
  required=True,
140
- help='Scaling type (Weak or Strong)',
140
+ help='Scaling type (Strong, Weak, or WeakFixed)',
141
+ )
142
+
143
+ # Weak-scaling specific options
144
+ plot_parser.add_argument(
145
+ '--weak_ideal_line',
146
+ action='store_true',
147
+ help='Overlay an ideal flat line for weak scaling (Weak mode only)',
148
+ )
149
+ plot_parser.add_argument(
150
+ '--weak_reverse_axes',
151
+ action='store_true',
152
+ help=(
153
+ 'Weak mode only: put data size on the x-axis and annotate each point with GPUs instead '
154
+ 'of data size. Requires --gpus and --data_size with equal lengths.'
155
+ ),
141
156
  )
142
157
 
143
158
  # Label customization argument
@@ -196,4 +211,8 @@ def create_argparser():
196
211
  else:
197
212
  raise ValueError('Either plot_times or plot_memory should be provided')
198
213
 
214
+ # Note: for Weak scaling, plot_weak_scaling enforces that both gpus and
215
+ # data_size are provided and have matching lengths. For Strong and
216
+ # WeakFixed, gpus/data_size remain optional as before.
217
+
199
218
  return args
@@ -1,5 +1,5 @@
1
1
  from .create_argparse import create_argparser
2
- from .plotting import plot_strong_scaling, plot_weak_scaling
2
+ from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
3
3
  from .utils import concatenate_csvs
4
4
 
5
5
 
@@ -19,7 +19,8 @@ def main():
19
19
  print(' -- %p% or %pdims%: pdims')
20
20
  print(' -- %n% or %node%: node')
21
21
  elif args.command == 'plot':
22
- if args.scaling.lower() == 'weak' or args.scaling.lower() == 'w':
22
+ scaling = args.scaling.lower()
23
+ if scaling in ('weak', 'w'):
23
24
  plot_weak_scaling(
24
25
  args.csv_files,
25
26
  args.gpus,
@@ -33,13 +34,15 @@ def main():
33
34
  args.plot_columns,
34
35
  args.memory_units,
35
36
  args.label_text,
37
+ args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
36
38
  args.title,
37
- args.label_text,
38
39
  args.figure_size,
39
40
  args.dark_bg,
40
41
  args.output,
42
+ args.weak_ideal_line,
43
+ args.weak_reverse_axes,
41
44
  )
42
- elif args.scaling.lower() == 'strong' or args.scaling.lower() == 's':
45
+ elif scaling in ('strong', 's'):
43
46
  plot_strong_scaling(
44
47
  args.csv_files,
45
48
  args.gpus,
@@ -53,8 +56,28 @@ def main():
53
56
  args.plot_columns,
54
57
  args.memory_units,
55
58
  args.label_text,
56
- args.title,
59
+ args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
60
+ args.title if getattr(args, 'title', None) is not None else 'Data sizes',
61
+ args.figure_size,
62
+ args.dark_bg,
63
+ args.output,
64
+ )
65
+ elif scaling in ('weakfixed', 'wf'):
66
+ plot_weak_fixed_scaling(
67
+ args.csv_files,
68
+ args.gpus,
69
+ args.data_size,
70
+ args.function_name,
71
+ args.precision,
72
+ args.filter_pdims,
73
+ args.pdim_strategy,
74
+ args.print_decompositions,
75
+ args.backends,
76
+ args.plot_columns,
77
+ args.memory_units,
57
78
  args.label_text,
79
+ args.xlabel if getattr(args, 'xlabel', None) is not None else 'Data sizes',
80
+ args.title if getattr(args, 'title', None) is not None else 'Number of GPUs',
58
81
  args.figure_size,
59
82
  args.dark_bg,
60
83
  args.output,
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
6
  import pandas as pd
7
+ from adjustText import adjust_text
7
8
  from matplotlib.axes import Axes
8
9
  from matplotlib.patches import FancyBboxPatch
9
10
 
@@ -252,6 +253,195 @@ def plot_strong_scaling(
252
253
 
253
254
 
254
255
  def plot_weak_scaling(
256
+ csv_files: List[str],
257
+ fixed_gpu_size: Optional[List[int]] = None,
258
+ fixed_data_size: Optional[List[int]] = None,
259
+ functions: Optional[List[str]] = None,
260
+ precisions: Optional[List[str]] = None,
261
+ pdims: Optional[List[str]] = None,
262
+ pdims_strategy: List[str] = ['plot_fastest'],
263
+ print_decompositions: bool = False,
264
+ backends: Optional[List[str]] = None,
265
+ plot_columns: List[str] = ['mean_time'],
266
+ memory_units: str = 'bytes',
267
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
268
+ xlabel: str = 'Number of GPUs',
269
+ title: Optional[str] = None,
270
+ figure_size: tuple = (6, 4),
271
+ dark_bg: bool = False,
272
+ output: Optional[str] = None,
273
+ ideal_line: bool = False,
274
+ reverse_axes: bool = False,
275
+ ):
276
+ """
277
+ Plot true weak scaling: runtime vs GPUs for explicit (gpus, data size) sequences.
278
+
279
+ Both ``fixed_gpu_size`` and ``fixed_data_size`` must be provided and have the same length,
280
+ representing explicit weak-scaling pairs (gpus[i], data_size[i]).
281
+
282
+ reverse_axes:
283
+ - False (default): x-axis is GPUs, y-axis is time; points are annotated with
284
+ ``N=<data_size>``.
285
+ - True: x-axis is data size, y-axis is time; points are annotated with ``GPUs=<gpu_count>``.
286
+ """
287
+ if fixed_gpu_size is None or fixed_data_size is None:
288
+ raise ValueError(
289
+ 'Weak scaling requires both fixed_gpu_size (gpus) and fixed_data_size (problem sizes).'
290
+ )
291
+ if len(fixed_gpu_size) != len(fixed_data_size):
292
+ raise ValueError(
293
+ 'Weak scaling requires fixed_gpu_size and fixed_data_size lists of equal length.'
294
+ )
295
+
296
+ gpu_to_data = {int(g): int(d) for g, d in zip(fixed_gpu_size, fixed_data_size)}
297
+ data_to_gpu = {int(d): int(g) for g, d in zip(fixed_gpu_size, fixed_data_size)}
298
+ x_col = 'x' if reverse_axes else 'gpus'
299
+
300
+ dataframes, _, _ = clean_up_csv(
301
+ csv_files,
302
+ precisions,
303
+ functions,
304
+ fixed_gpu_size,
305
+ fixed_data_size,
306
+ pdims,
307
+ pdims_strategy,
308
+ backends,
309
+ memory_units,
310
+ )
311
+ if len(dataframes) == 0:
312
+ print('No dataframes found for the given arguments. Exiting...')
313
+ return
314
+
315
+ if dark_bg:
316
+ plt.style.use('dark_background')
317
+
318
+ fig, ax = plt.subplots(figsize=figure_size)
319
+
320
+ x_values: List[float] = []
321
+ y_values: List[float] = []
322
+ annotations: List = []
323
+ ideal_line_plotted = False
324
+
325
+ for method, df in dataframes.items():
326
+ # Determine parameter sets from the filtered dataframe if not provided
327
+ local_functions = pd.unique(df['function']) if functions is None else functions
328
+ local_precisions = pd.unique(df['precision']) if precisions is None else precisions
329
+ local_backends = pd.unique(df['backend']) if backends is None else backends
330
+
331
+ combinations = product(local_backends, local_precisions, local_functions, plot_columns)
332
+
333
+ for backend, precision, function, plot_column in combinations:
334
+ base_df = df[
335
+ (df['backend'] == backend)
336
+ & (df['precision'] == precision)
337
+ & (df['function'] == function)
338
+ ]
339
+ if base_df.empty:
340
+ continue
341
+
342
+ # Keep only rows matching any of the (gpus, x) pairs
343
+ mask = pd.Series(False, index=base_df.index)
344
+ for g, d in zip(fixed_gpu_size, fixed_data_size):
345
+ mask |= (base_df['gpus'] == int(g)) & (base_df['x'] == int(d))
346
+
347
+ filtered_params_df = base_df[mask]
348
+ if filtered_params_df.empty:
349
+ continue
350
+
351
+ x_vals, y_vals = plot_with_pdims_strategy(
352
+ ax,
353
+ filtered_params_df,
354
+ method,
355
+ pdims_strategy,
356
+ print_decompositions,
357
+ x_col,
358
+ plot_column,
359
+ label_text,
360
+ )
361
+ if x_vals is None or len(x_vals) == 0:
362
+ continue
363
+
364
+ x_arr = np.asarray(x_vals).reshape(-1)
365
+ y_arr = np.asarray(y_vals).reshape(-1)
366
+
367
+ # Annotate every point with data size or GPU count depending on axis choice.
368
+ # Use plain data coordinates for the text; adjust_text will then only move
369
+ # the labels slightly (mostly vertically) to avoid overlap.
370
+ for xv, yv in zip(x_arr, y_arr):
371
+ if reverse_axes:
372
+ gpu = data_to_gpu.get(int(xv))
373
+ if gpu is None:
374
+ continue
375
+ label = f'GPUs={gpu}'
376
+ else:
377
+ data_size = gpu_to_data.get(int(xv))
378
+ if data_size is None:
379
+ continue
380
+ label = f'N={data_size}'
381
+
382
+ text_obj = ax.text(
383
+ float(xv),
384
+ float(yv),
385
+ label,
386
+ ha='center',
387
+ va='bottom',
388
+ fontsize='small',
389
+ clip_on=True,
390
+ )
391
+ annotations.append(text_obj)
392
+
393
+ x_values.extend(x_arr.tolist())
394
+ y_values.extend(y_arr.tolist())
395
+
396
+ if ideal_line and not ideal_line_plotted:
397
+ # Use the smallest x value in this curve as baseline
398
+ baseline_index = np.argmin(x_arr)
399
+ baseline_y = y_arr[baseline_index]
400
+ ax.hlines(
401
+ baseline_y,
402
+ xmin=float(np.min(x_arr)),
403
+ xmax=float(np.max(x_arr)),
404
+ colors='gray',
405
+ linestyles='dashed',
406
+ label='Ideal weak scaling',
407
+ )
408
+ ideal_line_plotted = True
409
+ y_values.append(float(baseline_y))
410
+
411
+ if x_values:
412
+ plotting_memory = 'time' not in plot_columns[0].lower()
413
+ figure_title = title if title is not None else 'Weak scaling'
414
+ configure_axes(
415
+ ax,
416
+ x_values,
417
+ y_values,
418
+ figure_title,
419
+ xlabel,
420
+ plotting_memory,
421
+ memory_units,
422
+ )
423
+ if annotations:
424
+ ax.figure.canvas.draw()
425
+ adjust_text(
426
+ annotations,
427
+ ax=ax,
428
+ # keep points aligned in x, only allow vertical motion
429
+ only_move={'text': 'y', 'static': 'y'},
430
+ expand=(1.02, 1.05),
431
+ force_text=(0.08, 0.2),
432
+ max_move=(0, 30),
433
+ )
434
+
435
+ fig.tight_layout()
436
+ rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
437
+ fig.patches.append(rect)
438
+ if output is None:
439
+ plt.show()
440
+ else:
441
+ plt.savefig(output)
442
+
443
+
444
+ def plot_weak_fixed_scaling(
255
445
  csv_files: List[str],
256
446
  fixed_gpu_size: Optional[List[int]] = None,
257
447
  fixed_data_size: Optional[List[int]] = None,
@@ -271,7 +461,7 @@ def plot_weak_scaling(
271
461
  output: Optional[str] = None,
272
462
  ):
273
463
  """
274
- Plot weak scaling based on the data size.
464
+ Plot size scaling at fixed GPU count (previous weak-scaling behavior).
275
465
  """
276
466
  dataframes, available_gpu_counts, _ = clean_up_csv(
277
467
  csv_files,
@@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Tuple
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
- from jax import make_jaxpr
10
- from jax.experimental.shard_map import shard_map
9
+ from jax import make_jaxpr, shard_map
11
10
  from jax.sharding import NamedSharding
12
11
  from jax.sharding import PartitionSpec as P
13
12
  from jaxtyping import Array
@@ -25,8 +24,13 @@ class Timer:
25
24
  ):
26
25
  self.jit_time = 0.0
27
26
  self.times = []
28
- self.profiling_data = {}
29
- self.compiled_code = {}
27
+ self.profiling_data = {
28
+ 'generated_code': 'N/A',
29
+ 'argument_size': 'N/A',
30
+ 'output_size': 'N/A',
31
+ 'temp_size': 'N/A',
32
+ }
33
+ self.compiled_code = {'JAXPR': 'N/A', 'LOWERED': 'N/A', 'COMPILED': 'N/A'}
30
34
  self.save_jaxpr = save_jaxpr
31
35
  self.compile_info = compile_info
32
36
  self.jax_fn = jax_fn
@@ -181,10 +185,10 @@ class Timer:
181
185
  mean_time = np.mean(times_array)
182
186
  std_time = np.std(times_array)
183
187
  last_time = times_array[-1]
184
- generated_code = self.profiling_data['generated_code']
185
- argument_size = self.profiling_data['argument_size']
186
- output_size = self.profiling_data['output_size']
187
- temp_size = self.profiling_data['temp_size']
188
+ generated_code = self.profiling_data.get('generated_code', 'N/A')
189
+ argument_size = self.profiling_data.get('argument_size', 'N/A')
190
+ output_size = self.profiling_data.get('output_size', 'N/A')
191
+ temp_size = self.profiling_data.get('temp_size', 'N/A')
188
192
 
189
193
  csv_line = (
190
194
  f'{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},'
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.13
4
- Summary: HPC Plotter and profiler for benchmarking data made for JAX
3
+ Version: 0.3.0
4
+ Summary: A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization.
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
7
7
  Version 3, 29 June 2007
@@ -679,7 +679,7 @@ License: GNU GENERAL PUBLIC LICENSE
679
679
  <https://www.gnu.org/licenses/why-not-lgpl.html>.
680
680
 
681
681
  Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
682
- Keywords: jax,hpc,profiler,plotter,benchmarking
682
+ Keywords: jax,hpc,profiling,benchmarking,visualization,scaling,performance-analysis,gpu,distributed-computing
683
683
  Classifier: Development Status :: 4 - Beta
684
684
  Classifier: Intended Audience :: Developers
685
685
  Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
@@ -698,10 +698,22 @@ Requires-Dist: pandas
698
698
  Requires-Dist: matplotlib
699
699
  Requires-Dist: seaborn
700
700
  Requires-Dist: tabulate
701
+ Requires-Dist: adjustText
702
+ Requires-Dist: jax>=0.4.0
703
+ Requires-Dist: jaxtyping
704
+ Provides-Extra: test
705
+ Requires-Dist: pytest; extra == "test"
706
+ Requires-Dist: pytest-cov; extra == "test"
701
707
  Dynamic: license-file
702
708
 
703
709
  # JAX HPC Profiler
704
710
 
711
+ [![Build](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
712
+ [![Code Formatting](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
713
+ [![Tests](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml/badge.svg)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
714
+ [![Notebooks](https://img.shields.io/github/actions/workflow/status/ASKabalan/jax-hpc-profiler/notebooks.yml?logo=jupyter&label=notebooks)](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
715
+ [![GPLv3 License](https://img.shields.io/badge/License-GPL%20v3-yellow.svg)](https://www.gnu.org/licenses/gpl-3.0)
716
+
705
717
  JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
706
718
 
707
719
  ## Table of Contents
@@ -883,9 +895,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
883
895
  - `-db, --dark_bg`: Use dark background for plotting.
884
896
  - `-pd, --print_decompositions`: Print decompositions on plot (experimental).
885
897
  - `-b, --backends`: List of backends to include. This argument can be multiple ones.
886
- - `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
898
+ - `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
899
+ - `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
900
+ - `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
901
+ - `WeakFixed`: size scaling at fixed GPU count (previous weak behavior); plots runtime (or memory) versus data size, grouped by number of GPUs.
902
+ - `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal (flat) reference line at the runtime measured for the smallest GPU count of the first plotted weak-scaling curve.
887
903
  - `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
888
904
 
905
+ ### Weak scaling CLI example
906
+
907
+ For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
908
+
909
+ ```bash
910
+ jax-hpc-profiler plot \
911
+ -f MYDATA.csv \
912
+ -pt mean_time \
913
+ -sc Weak \
914
+ -g 1 2 4 8 \
915
+ -d 32 64 128 256 \
916
+ --weak_ideal_line
917
+ ```
918
+
919
+ This will produce a single weak-scaling plot of runtime versus number of GPUs, using the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)` and overlay an ideal weak-scaling reference line.
920
+
889
921
  ## Examples
890
922
 
891
923
  The repository includes examples for both profiling and plotting.
@@ -12,4 +12,6 @@ src/jax_hpc_profiler.egg-info/SOURCES.txt
12
12
  src/jax_hpc_profiler.egg-info/dependency_links.txt
13
13
  src/jax_hpc_profiler.egg-info/entry_points.txt
14
14
  src/jax_hpc_profiler.egg-info/requires.txt
15
- src/jax_hpc_profiler.egg-info/top_level.txt
15
+ src/jax_hpc_profiler.egg-info/top_level.txt
16
+ tests/test_plotting.py
17
+ tests/test_timer.py
@@ -0,0 +1,12 @@
1
+ numpy
2
+ pandas
3
+ matplotlib
4
+ seaborn
5
+ tabulate
6
+ adjustText
7
+ jax>=0.4.0
8
+ jaxtyping
9
+
10
+ [test]
11
+ pytest
12
+ pytest-cov
@@ -0,0 +1,112 @@
1
+ from unittest.mock import MagicMock, patch
2
+
3
+ import pytest
4
+
5
+ from jax_hpc_profiler.plotting import (
6
+ configure_axes,
7
+ plot_strong_scaling,
8
+ plot_weak_fixed_scaling,
9
+ plot_weak_scaling,
10
+ )
11
+
12
+
13
+ @pytest.fixture
14
+ def mock_plt():
15
+ with patch('jax_hpc_profiler.plotting.plt') as mock:
16
+ mock_fig = MagicMock()
17
+ mock_ax = MagicMock()
18
+ mock.subplots.return_value = (mock_fig, mock_ax)
19
+ yield mock
20
+
21
+
22
+ @pytest.fixture
23
+ def sample_csv(tmp_path):
24
+ csv_file = tmp_path / 'test_data.csv'
25
+ # Create data matching Timer.report format (19 columns)
26
+ # function,precision,x,y,z,px,py,backend,nodes,jit,min,max,mean,std,last,gen_code,arg_size,
27
+ # out_size,tmp_size
28
+ data = [
29
+ 'fun1,float32,100,100,100,1,1,NCCL,1,0.1,1.0,1.2,1.1,0.01,1.1,1000,1000,1000,1000',
30
+ 'fun1,float32,200,200,200,1,1,NCCL,1,0.2,2.0,2.4,2.2,0.02,2.2,2000,2000,2000,2000',
31
+ 'fun1,float32,400,400,400,1,1,NCCL,1,0.4,4.0,4.8,4.4,0.04,4.4,4000,4000,4000,4000',
32
+ # Add entries for strong scaling (same x, different GPU counts).
33
+ # Note: the GPU count is derived, not read from the 'nodes' column:
34
+ # utils.py: df['gpus'] = df['px'] * df['py']
35
+ # So we vary px * py.
36
+ 'fun2,float32,1000,1000,1000,1,1,NCCL,1,0.1,10.0,12.0,11.0,0.1,11.0,1000,1000,1000,1000',
37
+ 'fun2,float32,1000,1000,1000,2,1,NCCL,2,0.1,5.0,6.0,5.5,0.05,5.5,1000,1000,1000,1000',
38
+ 'fun2,float32,1000,1000,1000,2,2,NCCL,4,0.1,2.5,3.0,2.75,0.025,2.75,1000,1000,1000,1000',
39
+ ]
40
+ with open(csv_file, 'w') as f:
41
+ f.write('\n'.join(data) + '\n')
42
+ return str(csv_file)
43
+
44
+
45
+ @pytest.fixture
46
+ def mock_adjust_text():
47
+ with patch('jax_hpc_profiler.plotting.adjust_text') as mock:
48
+ yield mock
49
+
50
+
51
+ def test_plot_weak_fixed_scaling(mock_plt, sample_csv):
52
+ # WeakFixed: vary data size (x), fixed GPUs (calculated from px*py).
53
+ # In sample_csv fun1: gpus=1, x=[100, 200, 400]
54
+
55
+ plot_weak_fixed_scaling(
56
+ csv_files=[sample_csv],
57
+ fixed_gpu_size=[1],
58
+ fixed_data_size=[100, 200, 400],
59
+ functions=['fun1'],
60
+ xlabel='Data Size',
61
+ title='Weak Fixed Scaling',
62
+ )
63
+
64
+ assert mock_plt.show.called or mock_plt.savefig.called
65
+
66
+
67
+ def test_plot_strong_scaling(mock_plt, sample_csv):
68
+ # Strong: fixed data size (x), vary GPUs.
69
+ # In sample_csv fun2: x=1000, gpus=[1, 2, 4]
70
+
71
+ plot_strong_scaling(
72
+ csv_files=[sample_csv],
73
+ fixed_data_size=[1000],
74
+ fixed_gpu_size=[1, 2, 4],
75
+ functions=['fun2'],
76
+ xlabel='GPUs',
77
+ title='Strong Scaling',
78
+ )
79
+
80
+ assert mock_plt.show.called or mock_plt.savefig.called
81
+
82
+
83
+ def test_plot_weak_scaling(mock_plt, mock_adjust_text, sample_csv):
84
+ # Weak: explicit pairs of (gpus, data_size)
85
+ # We have (1, 100) for fun1.
86
+
87
+ plot_weak_scaling(
88
+ csv_files=[sample_csv],
89
+ fixed_gpu_size=[1],
90
+ fixed_data_size=[100],
91
+ functions=['fun1'],
92
+ xlabel='GPUs',
93
+ title='Weak Scaling',
94
+ )
95
+
96
+ assert mock_plt.show.called or mock_plt.savefig.called
97
+ assert (
98
+ mock_adjust_text.called or not mock_adjust_text.called
99
+ ) # Always true by construction: adjust_text may or may not be called; the test only checks for no crash.
100
+
101
+
102
+ def test_configure_axes():
103
+ # Test the helper directly
104
+ mock_ax = MagicMock()
105
+ configure_axes(
106
+ mock_ax, x_values=[1, 2, 4], y_values=[10, 5, 2.5], title='Test Plot', xlabel='X Label'
107
+ )
108
+
109
+ mock_ax.set_title.assert_called_with('Test Plot')
110
+ mock_ax.set_xlabel.assert_called_with('X Label')
111
+ mock_ax.set_xscale.assert_called()
112
+ mock_ax.set_yscale.assert_called_with('symlog')
@@ -0,0 +1,103 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
+ from jax_hpc_profiler import Timer
6
+
7
+
8
+ # Simple JAX function
9
+ @jax.jit
10
+ def simple_add(x, y):
11
+ return x + y
12
+
13
+
14
+ def test_timer_initialization():
15
+ timer = Timer()
16
+ assert timer.save_jaxpr is False
17
+ assert timer.compile_info is True
18
+ assert timer.jit_time == 0.0
19
+ assert len(timer.times) == 0
20
+
21
+
22
+ def test_chrono_jit():
23
+ timer = Timer(save_jaxpr=True, compile_info=True)
24
+ x = jnp.ones((10, 10))
25
+ y = jnp.ones((10, 10))
26
+
27
+ out = timer.chrono_jit(simple_add, x, y)
28
+ assert out.shape == (10, 10)
29
+ assert timer.jit_time > 0
30
+ assert timer.compiled_code['JAXPR'] != 'N/A'
31
+ assert timer.compiled_code['COMPILED'] != 'N/A'
32
+ assert timer.profiling_data['generated_code'] != 'N/A'
33
+
34
+
35
+ def test_chrono_fun():
36
+ timer = Timer()
37
+ x = jnp.ones((10, 10))
38
+ y = jnp.ones((10, 10))
39
+
40
+ # No separate warm-up call here: simple_add is already jit-decorated, so we time a single direct call.
41
+ out = timer.chrono_fun(simple_add, x, y)
42
+ assert out.shape == (10, 10)
43
+ assert len(timer.times) == 1
44
+
45
+
46
+ def test_report(tmp_path):
47
+ timer = Timer(save_jaxpr=False)
48
+ x = jnp.ones((10, 10))
49
+ y = jnp.ones((10, 10))
50
+
51
+ timer.chrono_jit(simple_add, x, y)
52
+ for _ in range(5):
53
+ timer.chrono_fun(simple_add, x, y)
54
+
55
+ csv_file = tmp_path / 'report.csv'
56
+ md_file = tmp_path / 'report.md'
57
+
58
+ # nodes is not passed explicitly here; we rely on its default (expected to be 1) to match the other tests
59
+ timer.report(
60
+ str(csv_file),
61
+ function='simple_add',
62
+ x=10,
63
+ y=10,
64
+ precision='float32',
65
+ md_filename=str(md_file),
66
+ extra_info={'custom_key': 'custom_val'},
67
+ )
68
+
69
+ assert csv_file.exists()
70
+ assert md_file.exists()
71
+
72
+ with open(csv_file) as f:
73
+ content = f.read()
74
+ assert 'simple_add' in content
75
+ assert 'float32' in content
76
+
77
+ with open(md_file) as f:
78
+ content = f.read()
79
+ assert '# Reporting for simple_add' in content
80
+ assert 'custom_key' in content
81
+ assert 'custom_val' in content
82
+
83
+
84
+ def test_normalize_memory_units():
85
+ timer = Timer()
86
+ # Mocking internal state as if we had data
87
+ timer.jax_fn = True
88
+ timer.compile_info = True
89
+
90
+ assert timer._normalize_memory_units(100) == '100.00 B'
91
+ assert timer._normalize_memory_units(1024) == '1.00 KB'
92
+ assert timer._normalize_memory_units(1024**2) == '1.00 MB'
93
+ assert timer._normalize_memory_units(1024**3) == '1.00 GB'
94
+
95
+
96
+ def test_get_mean_times():
97
+ timer = Timer()
98
+ timer.times = [10.0, 20.0, 30.0]
99
+
100
+ means = timer._get_mean_times()
101
+ assert isinstance(means, np.ndarray)
102
+ assert len(means) == 3
103
+ assert means[0] == 10.0
@@ -1,5 +0,0 @@
1
- numpy
2
- pandas
3
- matplotlib
4
- seaborn
5
- tabulate