jax-hpc-profiler 0.2.13__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/__init__.py +2 -1
- jax_hpc_profiler/create_argparse.py +21 -2
- jax_hpc_profiler/main.py +28 -5
- jax_hpc_profiler/plotting.py +191 -1
- jax_hpc_profiler/timer.py +12 -8
- {jax_hpc_profiler-0.2.13.dist-info → jax_hpc_profiler-0.3.0.dist-info}/METADATA +36 -4
- jax_hpc_profiler-0.3.0.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.13.dist-info → jax_hpc_profiler-0.3.0.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.13.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.13.dist-info → jax_hpc_profiler-0.3.0.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.13.dist-info → jax_hpc_profiler-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {jax_hpc_profiler-0.2.13.dist-info → jax_hpc_profiler-0.3.0.dist-info}/top_level.txt +0 -0
jax_hpc_profiler/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .create_argparse import create_argparser
|
|
2
|
-
from .plotting import plot_strong_scaling, plot_weak_scaling
|
|
2
|
+
from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
|
|
3
3
|
from .timer import Timer
|
|
4
4
|
from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
|
|
5
5
|
|
|
@@ -7,6 +7,7 @@ __all__ = [
|
|
|
7
7
|
'create_argparser',
|
|
8
8
|
'plot_strong_scaling',
|
|
9
9
|
'plot_weak_scaling',
|
|
10
|
+
'plot_weak_fixed_scaling',
|
|
10
11
|
'Timer',
|
|
11
12
|
'clean_up_csv',
|
|
12
13
|
'concatenate_csvs',
|
|
@@ -135,9 +135,24 @@ def create_argparser():
|
|
|
135
135
|
plot_parser.add_argument(
|
|
136
136
|
'-sc',
|
|
137
137
|
'--scaling',
|
|
138
|
-
choices=['Weak', 'Strong', 'w', 's'],
|
|
138
|
+
choices=['Weak', 'Strong', 'WeakFixed', 'w', 's', 'wf'],
|
|
139
139
|
required=True,
|
|
140
|
-
help='Scaling type (Weak or
|
|
140
|
+
help='Scaling type (Strong, Weak, or WeakFixed)',
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Weak-scaling specific options
|
|
144
|
+
plot_parser.add_argument(
|
|
145
|
+
'--weak_ideal_line',
|
|
146
|
+
action='store_true',
|
|
147
|
+
help='Overlay an ideal flat line for weak scaling (Weak mode only)',
|
|
148
|
+
)
|
|
149
|
+
plot_parser.add_argument(
|
|
150
|
+
'--weak_reverse_axes',
|
|
151
|
+
action='store_true',
|
|
152
|
+
help=(
|
|
153
|
+
'Weak mode only: put data size on the x-axis and annotate each point with GPUs instead '
|
|
154
|
+
'of data size. Requires --gpus and --data_size with equal lengths.'
|
|
155
|
+
),
|
|
141
156
|
)
|
|
142
157
|
|
|
143
158
|
# Label customization argument
|
|
@@ -196,4 +211,8 @@ def create_argparser():
|
|
|
196
211
|
else:
|
|
197
212
|
raise ValueError('Either plot_times or plot_memory should be provided')
|
|
198
213
|
|
|
214
|
+
# Note: for Weak scaling, plot_weak_scaling enforces that both gpus and
|
|
215
|
+
# data_size are provided and have matching lengths. For Strong and
|
|
216
|
+
# WeakFixed, gpus/data_size remain optional as before.
|
|
217
|
+
|
|
199
218
|
return args
|
jax_hpc_profiler/main.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .create_argparse import create_argparser
|
|
2
|
-
from .plotting import plot_strong_scaling, plot_weak_scaling
|
|
2
|
+
from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
|
|
3
3
|
from .utils import concatenate_csvs
|
|
4
4
|
|
|
5
5
|
|
|
@@ -19,7 +19,8 @@ def main():
|
|
|
19
19
|
print(' -- %p% or %pdims%: pdims')
|
|
20
20
|
print(' -- %n% or %node%: node')
|
|
21
21
|
elif args.command == 'plot':
|
|
22
|
-
|
|
22
|
+
scaling = args.scaling.lower()
|
|
23
|
+
if scaling in ('weak', 'w'):
|
|
23
24
|
plot_weak_scaling(
|
|
24
25
|
args.csv_files,
|
|
25
26
|
args.gpus,
|
|
@@ -33,13 +34,15 @@ def main():
|
|
|
33
34
|
args.plot_columns,
|
|
34
35
|
args.memory_units,
|
|
35
36
|
args.label_text,
|
|
37
|
+
args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
|
|
36
38
|
args.title,
|
|
37
|
-
args.label_text,
|
|
38
39
|
args.figure_size,
|
|
39
40
|
args.dark_bg,
|
|
40
41
|
args.output,
|
|
42
|
+
args.weak_ideal_line,
|
|
43
|
+
args.weak_reverse_axes,
|
|
41
44
|
)
|
|
42
|
-
elif
|
|
45
|
+
elif scaling in ('strong', 's'):
|
|
43
46
|
plot_strong_scaling(
|
|
44
47
|
args.csv_files,
|
|
45
48
|
args.gpus,
|
|
@@ -53,8 +56,28 @@ def main():
|
|
|
53
56
|
args.plot_columns,
|
|
54
57
|
args.memory_units,
|
|
55
58
|
args.label_text,
|
|
56
|
-
args.
|
|
59
|
+
args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
|
|
60
|
+
args.title if getattr(args, 'title', None) is not None else 'Data sizes',
|
|
61
|
+
args.figure_size,
|
|
62
|
+
args.dark_bg,
|
|
63
|
+
args.output,
|
|
64
|
+
)
|
|
65
|
+
elif scaling in ('weakfixed', 'wf'):
|
|
66
|
+
plot_weak_fixed_scaling(
|
|
67
|
+
args.csv_files,
|
|
68
|
+
args.gpus,
|
|
69
|
+
args.data_size,
|
|
70
|
+
args.function_name,
|
|
71
|
+
args.precision,
|
|
72
|
+
args.filter_pdims,
|
|
73
|
+
args.pdim_strategy,
|
|
74
|
+
args.print_decompositions,
|
|
75
|
+
args.backends,
|
|
76
|
+
args.plot_columns,
|
|
77
|
+
args.memory_units,
|
|
57
78
|
args.label_text,
|
|
79
|
+
args.xlabel if getattr(args, 'xlabel', None) is not None else 'Data sizes',
|
|
80
|
+
args.title if getattr(args, 'title', None) is not None else 'Number of GPUs',
|
|
58
81
|
args.figure_size,
|
|
59
82
|
args.dark_bg,
|
|
60
83
|
args.output,
|
jax_hpc_profiler/plotting.py
CHANGED
|
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pandas as pd
|
|
7
|
+
from adjustText import adjust_text
|
|
7
8
|
from matplotlib.axes import Axes
|
|
8
9
|
from matplotlib.patches import FancyBboxPatch
|
|
9
10
|
|
|
@@ -252,6 +253,195 @@ def plot_strong_scaling(
|
|
|
252
253
|
|
|
253
254
|
|
|
254
255
|
def plot_weak_scaling(
|
|
256
|
+
csv_files: List[str],
|
|
257
|
+
fixed_gpu_size: Optional[List[int]] = None,
|
|
258
|
+
fixed_data_size: Optional[List[int]] = None,
|
|
259
|
+
functions: Optional[List[str]] = None,
|
|
260
|
+
precisions: Optional[List[str]] = None,
|
|
261
|
+
pdims: Optional[List[str]] = None,
|
|
262
|
+
pdims_strategy: List[str] = ['plot_fastest'],
|
|
263
|
+
print_decompositions: bool = False,
|
|
264
|
+
backends: Optional[List[str]] = None,
|
|
265
|
+
plot_columns: List[str] = ['mean_time'],
|
|
266
|
+
memory_units: str = 'bytes',
|
|
267
|
+
label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
|
|
268
|
+
xlabel: str = 'Number of GPUs',
|
|
269
|
+
title: Optional[str] = None,
|
|
270
|
+
figure_size: tuple = (6, 4),
|
|
271
|
+
dark_bg: bool = False,
|
|
272
|
+
output: Optional[str] = None,
|
|
273
|
+
ideal_line: bool = False,
|
|
274
|
+
reverse_axes: bool = False,
|
|
275
|
+
):
|
|
276
|
+
"""
|
|
277
|
+
Plot true weak scaling: runtime vs GPUs for explicit (gpus, data size) sequences.
|
|
278
|
+
|
|
279
|
+
Both ``fixed_gpu_size`` and ``fixed_data_size`` must be provided and have the same length,
|
|
280
|
+
representing explicit weak-scaling pairs (gpus[i], data_size[i]).
|
|
281
|
+
|
|
282
|
+
reverse_axes:
|
|
283
|
+
- False (default): x-axis is GPUs, y-axis is time; points are annotated with
|
|
284
|
+
``N=<data_size>``.
|
|
285
|
+
- True: x-axis is data size, y-axis is time; points are annotated with ``GPUs=<gpu_count>``.
|
|
286
|
+
"""
|
|
287
|
+
if fixed_gpu_size is None or fixed_data_size is None:
|
|
288
|
+
raise ValueError(
|
|
289
|
+
'Weak scaling requires both fixed_gpu_size (gpus) and fixed_data_size (problem sizes).'
|
|
290
|
+
)
|
|
291
|
+
if len(fixed_gpu_size) != len(fixed_data_size):
|
|
292
|
+
raise ValueError(
|
|
293
|
+
'Weak scaling requires fixed_gpu_size and fixed_data_size lists of equal length.'
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
gpu_to_data = {int(g): int(d) for g, d in zip(fixed_gpu_size, fixed_data_size)}
|
|
297
|
+
data_to_gpu = {int(d): int(g) for g, d in zip(fixed_gpu_size, fixed_data_size)}
|
|
298
|
+
x_col = 'x' if reverse_axes else 'gpus'
|
|
299
|
+
|
|
300
|
+
dataframes, _, _ = clean_up_csv(
|
|
301
|
+
csv_files,
|
|
302
|
+
precisions,
|
|
303
|
+
functions,
|
|
304
|
+
fixed_gpu_size,
|
|
305
|
+
fixed_data_size,
|
|
306
|
+
pdims,
|
|
307
|
+
pdims_strategy,
|
|
308
|
+
backends,
|
|
309
|
+
memory_units,
|
|
310
|
+
)
|
|
311
|
+
if len(dataframes) == 0:
|
|
312
|
+
print('No dataframes found for the given arguments. Exiting...')
|
|
313
|
+
return
|
|
314
|
+
|
|
315
|
+
if dark_bg:
|
|
316
|
+
plt.style.use('dark_background')
|
|
317
|
+
|
|
318
|
+
fig, ax = plt.subplots(figsize=figure_size)
|
|
319
|
+
|
|
320
|
+
x_values: List[float] = []
|
|
321
|
+
y_values: List[float] = []
|
|
322
|
+
annotations: List = []
|
|
323
|
+
ideal_line_plotted = False
|
|
324
|
+
|
|
325
|
+
for method, df in dataframes.items():
|
|
326
|
+
# Determine parameter sets from the filtered dataframe if not provided
|
|
327
|
+
local_functions = pd.unique(df['function']) if functions is None else functions
|
|
328
|
+
local_precisions = pd.unique(df['precision']) if precisions is None else precisions
|
|
329
|
+
local_backends = pd.unique(df['backend']) if backends is None else backends
|
|
330
|
+
|
|
331
|
+
combinations = product(local_backends, local_precisions, local_functions, plot_columns)
|
|
332
|
+
|
|
333
|
+
for backend, precision, function, plot_column in combinations:
|
|
334
|
+
base_df = df[
|
|
335
|
+
(df['backend'] == backend)
|
|
336
|
+
& (df['precision'] == precision)
|
|
337
|
+
& (df['function'] == function)
|
|
338
|
+
]
|
|
339
|
+
if base_df.empty:
|
|
340
|
+
continue
|
|
341
|
+
|
|
342
|
+
# Keep only rows matching any of the (gpus, x) pairs
|
|
343
|
+
mask = pd.Series(False, index=base_df.index)
|
|
344
|
+
for g, d in zip(fixed_gpu_size, fixed_data_size):
|
|
345
|
+
mask |= (base_df['gpus'] == int(g)) & (base_df['x'] == int(d))
|
|
346
|
+
|
|
347
|
+
filtered_params_df = base_df[mask]
|
|
348
|
+
if filtered_params_df.empty:
|
|
349
|
+
continue
|
|
350
|
+
|
|
351
|
+
x_vals, y_vals = plot_with_pdims_strategy(
|
|
352
|
+
ax,
|
|
353
|
+
filtered_params_df,
|
|
354
|
+
method,
|
|
355
|
+
pdims_strategy,
|
|
356
|
+
print_decompositions,
|
|
357
|
+
x_col,
|
|
358
|
+
plot_column,
|
|
359
|
+
label_text,
|
|
360
|
+
)
|
|
361
|
+
if x_vals is None or len(x_vals) == 0:
|
|
362
|
+
continue
|
|
363
|
+
|
|
364
|
+
x_arr = np.asarray(x_vals).reshape(-1)
|
|
365
|
+
y_arr = np.asarray(y_vals).reshape(-1)
|
|
366
|
+
|
|
367
|
+
# Annotate every point with data size or GPU count depending on axis choice.
|
|
368
|
+
# Use plain data coordinates for the text; adjust_text will then only move
|
|
369
|
+
# the labels slightly (mostly vertically) to avoid overlap.
|
|
370
|
+
for xv, yv in zip(x_arr, y_arr):
|
|
371
|
+
if reverse_axes:
|
|
372
|
+
gpu = data_to_gpu.get(int(xv))
|
|
373
|
+
if gpu is None:
|
|
374
|
+
continue
|
|
375
|
+
label = f'GPUs={gpu}'
|
|
376
|
+
else:
|
|
377
|
+
data_size = gpu_to_data.get(int(xv))
|
|
378
|
+
if data_size is None:
|
|
379
|
+
continue
|
|
380
|
+
label = f'N={data_size}'
|
|
381
|
+
|
|
382
|
+
text_obj = ax.text(
|
|
383
|
+
float(xv),
|
|
384
|
+
float(yv),
|
|
385
|
+
label,
|
|
386
|
+
ha='center',
|
|
387
|
+
va='bottom',
|
|
388
|
+
fontsize='small',
|
|
389
|
+
clip_on=True,
|
|
390
|
+
)
|
|
391
|
+
annotations.append(text_obj)
|
|
392
|
+
|
|
393
|
+
x_values.extend(x_arr.tolist())
|
|
394
|
+
y_values.extend(y_arr.tolist())
|
|
395
|
+
|
|
396
|
+
if ideal_line and not ideal_line_plotted:
|
|
397
|
+
# Use the smallest x value in this curve as baseline
|
|
398
|
+
baseline_index = np.argmin(x_arr)
|
|
399
|
+
baseline_y = y_arr[baseline_index]
|
|
400
|
+
ax.hlines(
|
|
401
|
+
baseline_y,
|
|
402
|
+
xmin=float(np.min(x_arr)),
|
|
403
|
+
xmax=float(np.max(x_arr)),
|
|
404
|
+
colors='gray',
|
|
405
|
+
linestyles='dashed',
|
|
406
|
+
label='Ideal weak scaling',
|
|
407
|
+
)
|
|
408
|
+
ideal_line_plotted = True
|
|
409
|
+
y_values.append(float(baseline_y))
|
|
410
|
+
|
|
411
|
+
if x_values:
|
|
412
|
+
plotting_memory = 'time' not in plot_columns[0].lower()
|
|
413
|
+
figure_title = title if title is not None else 'Weak scaling'
|
|
414
|
+
configure_axes(
|
|
415
|
+
ax,
|
|
416
|
+
x_values,
|
|
417
|
+
y_values,
|
|
418
|
+
figure_title,
|
|
419
|
+
xlabel,
|
|
420
|
+
plotting_memory,
|
|
421
|
+
memory_units,
|
|
422
|
+
)
|
|
423
|
+
if annotations:
|
|
424
|
+
ax.figure.canvas.draw()
|
|
425
|
+
adjust_text(
|
|
426
|
+
annotations,
|
|
427
|
+
ax=ax,
|
|
428
|
+
# keep points aligned in x, only allow vertical motion
|
|
429
|
+
only_move={'text': 'y', 'static': 'y'},
|
|
430
|
+
expand=(1.02, 1.05),
|
|
431
|
+
force_text=(0.08, 0.2),
|
|
432
|
+
max_move=(0, 30),
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
fig.tight_layout()
|
|
436
|
+
rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
|
|
437
|
+
fig.patches.append(rect)
|
|
438
|
+
if output is None:
|
|
439
|
+
plt.show()
|
|
440
|
+
else:
|
|
441
|
+
plt.savefig(output)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def plot_weak_fixed_scaling(
|
|
255
445
|
csv_files: List[str],
|
|
256
446
|
fixed_gpu_size: Optional[List[int]] = None,
|
|
257
447
|
fixed_data_size: Optional[List[int]] = None,
|
|
@@ -271,7 +461,7 @@ def plot_weak_scaling(
|
|
|
271
461
|
output: Optional[str] = None,
|
|
272
462
|
):
|
|
273
463
|
"""
|
|
274
|
-
Plot
|
|
464
|
+
Plot size scaling at fixed GPU count (previous weak-scaling behavior).
|
|
275
465
|
"""
|
|
276
466
|
dataframes, available_gpu_counts, _ = clean_up_csv(
|
|
277
467
|
csv_files,
|
jax_hpc_profiler/timer.py
CHANGED
|
@@ -6,8 +6,7 @@ from typing import Any, Callable, Optional, Tuple
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import numpy as np
|
|
9
|
-
from jax import make_jaxpr
|
|
10
|
-
from jax.experimental.shard_map import shard_map
|
|
9
|
+
from jax import make_jaxpr, shard_map
|
|
11
10
|
from jax.sharding import NamedSharding
|
|
12
11
|
from jax.sharding import PartitionSpec as P
|
|
13
12
|
from jaxtyping import Array
|
|
@@ -25,8 +24,13 @@ class Timer:
|
|
|
25
24
|
):
|
|
26
25
|
self.jit_time = 0.0
|
|
27
26
|
self.times = []
|
|
28
|
-
self.profiling_data = {
|
|
29
|
-
|
|
27
|
+
self.profiling_data = {
|
|
28
|
+
'generated_code': 'N/A',
|
|
29
|
+
'argument_size': 'N/A',
|
|
30
|
+
'output_size': 'N/A',
|
|
31
|
+
'temp_size': 'N/A',
|
|
32
|
+
}
|
|
33
|
+
self.compiled_code = {'JAXPR': 'N/A', 'LOWERED': 'N/A', 'COMPILED': 'N/A'}
|
|
30
34
|
self.save_jaxpr = save_jaxpr
|
|
31
35
|
self.compile_info = compile_info
|
|
32
36
|
self.jax_fn = jax_fn
|
|
@@ -181,10 +185,10 @@ class Timer:
|
|
|
181
185
|
mean_time = np.mean(times_array)
|
|
182
186
|
std_time = np.std(times_array)
|
|
183
187
|
last_time = times_array[-1]
|
|
184
|
-
generated_code = self.profiling_data
|
|
185
|
-
argument_size = self.profiling_data
|
|
186
|
-
output_size = self.profiling_data
|
|
187
|
-
temp_size = self.profiling_data
|
|
188
|
+
generated_code = self.profiling_data.get('generated_code', 'N/A')
|
|
189
|
+
argument_size = self.profiling_data.get('argument_size', 'N/A')
|
|
190
|
+
output_size = self.profiling_data.get('output_size', 'N/A')
|
|
191
|
+
temp_size = self.profiling_data.get('temp_size', 'N/A')
|
|
188
192
|
|
|
189
193
|
csv_line = (
|
|
190
194
|
f'{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},'
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jax_hpc_profiler
|
|
3
|
-
Version: 0.
|
|
4
|
-
Summary:
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization.
|
|
5
5
|
Author: Wassim Kabalan
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
7
7
|
Version 3, 29 June 2007
|
|
@@ -679,7 +679,7 @@ License: GNU GENERAL PUBLIC LICENSE
|
|
|
679
679
|
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
|
680
680
|
|
|
681
681
|
Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
|
|
682
|
-
Keywords: jax,hpc,
|
|
682
|
+
Keywords: jax,hpc,profiling,benchmarking,visualization,scaling,performance-analysis,gpu,distributed-computing
|
|
683
683
|
Classifier: Development Status :: 4 - Beta
|
|
684
684
|
Classifier: Intended Audience :: Developers
|
|
685
685
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
@@ -698,10 +698,22 @@ Requires-Dist: pandas
|
|
|
698
698
|
Requires-Dist: matplotlib
|
|
699
699
|
Requires-Dist: seaborn
|
|
700
700
|
Requires-Dist: tabulate
|
|
701
|
+
Requires-Dist: adjustText
|
|
702
|
+
Requires-Dist: jax>=0.4.0
|
|
703
|
+
Requires-Dist: jaxtyping
|
|
704
|
+
Provides-Extra: test
|
|
705
|
+
Requires-Dist: pytest; extra == "test"
|
|
706
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
701
707
|
Dynamic: license-file
|
|
702
708
|
|
|
703
709
|
# JAX HPC Profiler
|
|
704
710
|
|
|
711
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
|
|
712
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
|
|
713
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
|
|
714
|
+
[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
|
|
715
|
+
[](https://www.gnu.org/licenses/gpl-3.0)
|
|
716
|
+
|
|
705
717
|
JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
|
|
706
718
|
|
|
707
719
|
## Table of Contents
|
|
@@ -883,9 +895,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
|
|
|
883
895
|
- `-db, --dark_bg`: Use dark background for plotting.
|
|
884
896
|
- `-pd, --print_decompositions`: Print decompositions on plot (experimental).
|
|
885
897
|
- `-b, --backends`: List of backends to include. This argument can be multiple ones.
|
|
886
|
-
- `-sc, --scaling`: Scaling type (`Weak`, `
|
|
898
|
+
- `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
|
|
899
|
+
- `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
|
|
900
|
+
- `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
|
|
901
|
+
- `WeakFixed`: size scaling at fixed GPU count (previous weak behavior); plots runtime (or memory) versus data size, grouped by number of GPUs.
|
|
902
|
+
- `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal flat line based on the smallest-GPU runtime for the first plotted weak-scaling curve.
|
|
887
903
|
- `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
|
|
888
904
|
|
|
905
|
+
### Weak scaling CLI example
|
|
906
|
+
|
|
907
|
+
For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
|
|
908
|
+
|
|
909
|
+
```bash
|
|
910
|
+
jax-hpc-profiler plot \
|
|
911
|
+
-f MYDATA.csv \
|
|
912
|
+
-pt mean_time \
|
|
913
|
+
-sc Weak \
|
|
914
|
+
-g 1 2 4 8 \
|
|
915
|
+
-d 32 64 128 256 \
|
|
916
|
+
--weak_ideal_line
|
|
917
|
+
```
|
|
918
|
+
|
|
919
|
+
This will produce a single weak-scaling plot of runtime versus number of GPUs, using the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)` and overlay an ideal weak-scaling reference line.
|
|
920
|
+
|
|
889
921
|
## Examples
|
|
890
922
|
|
|
891
923
|
The repository includes examples for both profiling and plotting.
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
jax_hpc_profiler/__init__.py,sha256=GIYY_D3CqhTzEf9Bh8ihar_-ntDYSQpp2utFIgYRbYg,444
|
|
2
|
+
jax_hpc_profiler/create_argparse.py,sha256=ff2e8PvHmbbyF13OH2FTLlpnIUGp9xP8kS--XuJuhZ4,6582
|
|
3
|
+
jax_hpc_profiler/main.py,sha256=ehqU6HwqhjKLs_34tmzWFQU2G-kSiVmhKJ1HIAw-6Lg,3262
|
|
4
|
+
jax_hpc_profiler/plotting.py,sha256=vQsykw4JJNZn6Z6IR5_VABXEHKBhESQdAoAAN4dOaPk,15998
|
|
5
|
+
jax_hpc_profiler/timer.py,sha256=5coHheE6eaviLCZsPuXodbl7pYW9ora-GU9M6PJqRNQ,10442
|
|
6
|
+
jax_hpc_profiler/utils.py,sha256=IfGDbKldJXiDhxb02IxmQV51SFIBYLDUL7Se_OtEOkc,14963
|
|
7
|
+
jax_hpc_profiler-0.3.0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
+
jax_hpc_profiler-0.3.0.dist-info/METADATA,sha256=EI-Qb9STk1q_mJC2WMaBGLaqdj8Adtp0FpNeRRx6NNQ,51620
|
|
9
|
+
jax_hpc_profiler-0.3.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
10
|
+
jax_hpc_profiler-0.3.0.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
+
jax_hpc_profiler-0.3.0.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
+
jax_hpc_profiler-0.3.0.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
jax_hpc_profiler/__init__.py,sha256=c2n33ZXVgUS8vo5xAEW-TcSi_UzJp616KdGEb3iO6p4,388
|
|
2
|
-
jax_hpc_profiler/create_argparse.py,sha256=J1RF4n2e85QReoI_fqXxK5BMAUgzueHmObKOh4YHopE,5821
|
|
3
|
-
jax_hpc_profiler/main.py,sha256=YPLkZCmtjzNoDrzTA4CWL8y39Spz3qbCS91eP2pqP5Y,2224
|
|
4
|
-
jax_hpc_profiler/plotting.py,sha256=Lg157H3mrF3zHc4BIplddKu9f0viQkaQhtCCAQBxinE,9167
|
|
5
|
-
jax_hpc_profiler/timer.py,sha256=0lbJgNh3GT1dFOpNOA4Fwvsm9JNp-J1xDdLFaaQ6jaY,10237
|
|
6
|
-
jax_hpc_profiler/utils.py,sha256=IfGDbKldJXiDhxb02IxmQV51SFIBYLDUL7Se_OtEOkc,14963
|
|
7
|
-
jax_hpc_profiler-0.2.13.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
-
jax_hpc_profiler-0.2.13.dist-info/METADATA,sha256=YyHfP98Vz8ya23YsRPV2rehbRoFsO3pziOgnoX5DitE,49186
|
|
9
|
-
jax_hpc_profiler-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
10
|
-
jax_hpc_profiler-0.2.13.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
-
jax_hpc_profiler-0.2.13.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
-
jax_hpc_profiler-0.2.13.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|