jax-hpc-profiler 0.2.12__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- jax_hpc_profiler/__init__.py +9 -3
- jax_hpc_profiler/create_argparse.py +128 -120
- jax_hpc_profiler/main.py +41 -22
- jax_hpc_profiler/plotting.py +250 -68
- jax_hpc_profiler/timer.py +117 -126
- jax_hpc_profiler/utils.py +191 -132
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/METADATA +36 -4
- jax_hpc_profiler-0.3.0.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.12.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/top_level.txt +0 -0
jax_hpc_profiler/utils.py
CHANGED
@@ -1,13 +1,11 @@
 import os
 from typing import Dict, List, Optional, Tuple
 
-import matplotlib.pyplot as plt
-import numpy as np
 import pandas as pd
 from matplotlib.axes import Axes
 
 
-def inspect_data(dataframes: Dict[str, pd.DataFrame]):
+def inspect_data(dataframes: Dict[str, pd.DataFrame]) -> None:
     """
     Inspect the dataframes.
 
@@ -16,16 +14,16 @@ def inspect_data(dataframes: Dict[str, pd.DataFrame]):
     dataframes : Dict[str, pd.DataFrame]
         Dictionary of method names to dataframes.
     """
-    print(
-    print(
-    print(
+    print('=' * 80)
+    print('Inspecting dataframes...')
+    print('=' * 80)
     for method, df in dataframes.items():
-        print(f
+        print(f'Method: {method}')
         inspect_df(df)
-        print(
+        print('=' * 80)
 
 
-def inspect_df(df: pd.DataFrame):
+def inspect_df(df: pd.DataFrame) -> None:
     """
     Inspect the dataframe.
 
@@ -35,24 +33,24 @@ def inspect_df(df: pd.DataFrame):
         The dataframe to inspect.
     """
     print(df.to_markdown())
-    print(
+    print('-' * 80)
 
 
 params_dict = {
-
-
-
-
-
-
-
-
-
-
+    '%pn%': '%plot_name%',
+    '%m%': '%method_name%',
+    '%n%': '%node%',
+    '%b%': '%backend%',
+    '%f%': '%function%',
+    '%cn%': '%column_name%',
+    '%pr%': '%precision%',
+    '%p%': '%decomposition%',
+    '%d%': '%data_size%',
+    '%g%': '%nb_gpu%',
 }
 
 
-def expand_label(label_template: str, params: dict) -> str:
+def expand_label(label_template: str, params: dict[str, str]) -> str:
     """
     Expand the label template with the provided parameters.
 
@@ -72,14 +70,20 @@ def expand_label(label_template: str, params: dict) -> str:
         label_template = label_template.replace(key, value)
 
     for key, value in params.items():
-        label_template = label_template.replace(f
+        label_template = label_template.replace(f'%{key}%', str(value))
     return label_template
 
 
-def plot_with_pdims_strategy(
-
-
-
+def plot_with_pdims_strategy(
+    ax: Axes,
+    df: pd.DataFrame,
+    method: str,
+    pdims_strategy: List[str],
+    print_decompositions: bool,
+    x_col: str,
+    y_col: str,
+    label_template: str,
+) -> Optional[Tuple[List[float], List[float]]]:
     """
     Plot the data based on the pdims strategy.
 
@@ -109,12 +113,12 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
         Template for plot labels with placeholders.
     """
     label_params = {
-
-
-
-
-
-
+        'plot_name': y_col,
+        'method_name': method,
+        'backend': df['backend'].values[0],
+        'node': df['nodes'].values[0],
+        'precision': df['precision'].values[0],
+        'function': df['function'].values[0],
     }
 
     if 'plot_fastest' in pdims_strategy:
@@ -126,51 +130,50 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
             group.sort_values(by=[y_col], inplace=True, ascending=True)
             sorted_dfs.append(group.iloc[0])
         sorted_df = pd.DataFrame(sorted_dfs)
-        label_params.update({
-            "decomposition":
-            f"{group['px'].values[0]}x{group['py'].values[0]}"
-        })
+        label_params.update({'decomposition': f'{group["px"].values[0]}x{group["py"].values[0]}'})
         label = expand_label(label_template, label_params)
-        ax.plot(
-
-
-
-
+        ax.plot(
+            sorted_df[x_col].values,
+            sorted_df[y_col],
+            marker='o',
+            linestyle='-',
+            label=label,
+        )
         # TODO(wassim) : this is not working very well
         if print_decompositions:
-            for j, (px, py) in enumerate(zip(sorted_df['px'],
-                                             sorted_df['py'])):
+            for j, (px, py) in enumerate(zip(sorted_df['px'], sorted_df['py'])):
                 ax.annotate(
-                    f
+                    f'{px}x{py}',
                     (sorted_df[x_col].values[j], sorted_df[y_col].values[j]),
-                    textcoords=
+                    textcoords='offset points',
                     xytext=(0, 10),
                     ha='center',
-                    color='red' if j == 0 else 'white'
+                    color='red' if j == 0 else 'white',
+                )
         return sorted_df[x_col].values, sorted_df[y_col].values
 
-    elif any(
-
+    elif any(
+        strategy in pdims_strategy for strategy in ['plot_all', 'slab_yz', 'slab_xy', 'pencils']
+    ):
         df_decomp = df.groupby(['decomp'])
         x_values = []
         y_values = []
         for _, group in df_decomp:
-            group.drop_duplicates(subset=[x_col, 'decomp'],
-                                  keep='last',
-                                  inplace=True)
+            group.drop_duplicates(subset=[x_col, 'decomp'], keep='last', inplace=True)
            group.sort_values(by=[x_col], inplace=True, ascending=False)
            # filter decomp based on pdims_strategy
-            if 'plot_all' not in pdims_strategy and group['decomp'].values[
-                    0] not in pdims_strategy:
+            if 'plot_all' not in pdims_strategy and group['decomp'].values[0] not in pdims_strategy:
                continue
 
-            label_params.update({
+            label_params.update({'decomposition': group['decomp'].values[0]})
            label = expand_label(label_template, label_params)
-            ax.plot(
-
-
-
-
+            ax.plot(
+                group[x_col].values,
+                group[y_col],
+                marker='o',
+                linestyle='-',
+                label=label,
+            )
            x_values.extend(group[x_col].values)
            y_values.extend(group[y_col].values)
         return x_values, y_values
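The label-template machinery above is easiest to see end to end. Below is a minimal, self-contained sketch of the two-pass expansion the new code performs (shorthand keys such as `%m%` are first rewritten to long-form placeholders via `params_dict`, which are then filled from the per-plot parameters); the template and parameter values are invented for illustration:

```python
# Minimal sketch of the two-pass label expansion shown in the diff above.
# Shorthand placeholders ('%m%') are first rewritten to long forms
# ('%method_name%'), then the long forms are filled from `params`.
# The sample template and values below are invented for illustration.

params_dict = {
    '%pn%': '%plot_name%',
    '%m%': '%method_name%',
    '%b%': '%backend%',
    '%p%': '%decomposition%',
}


def expand_label(label_template: str, params: dict[str, str]) -> str:
    # Pass 1: normalize shorthand keys to long-form placeholders.
    for key, value in params_dict.items():
        label_template = label_template.replace(key, value)
    # Pass 2: substitute long-form placeholders with actual values.
    for key, value in params.items():
        label_template = label_template.replace(f'%{key}%', str(value))
    return label_template


print(expand_label(
    '%m% (%b%) %p%',
    {'method_name': 'jaxdecomp', 'backend': 'NCCL', 'decomposition': '4x2'},
))
# -> jaxdecomp (NCCL) 4x2
```

Normalizing shorthands first keeps templates like `%m% (%b%) %p%` compact, while the plotting code only ever deals with the long-form names.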
@@ -204,39 +207,63 @@ def concatenate_csvs(root_dir: str, output_dir: str):
             if file.endswith('.csv'):
                 csv_file_path = os.path.join(root, file)
                 print(f'Concatenating {csv_file_path}...')
-                df = pd.read_csv(
-
-
-
-
-
-
-
-
-
+                df = pd.read_csv(
+                    csv_file_path,
+                    header=None,
+                    names=[
+                        'function',
+                        'precision',
+                        'x',
+                        'y',
+                        'z',
+                        'px',
+                        'py',
+                        'backend',
+                        'nodes',
+                        'jit_time',
+                        'min_time',
+                        'max_time',
+                        'mean_time',
+                        'std_div',
+                        'last_time',
+                        'generated_code',
+                        'argument_size',
+                        'output_size',
+                        'temp_size',
+                        'flops',
+                    ],
+                    index_col=False,
+                )
                 if file not in combined_dfs:
                     combined_dfs[file] = df
                 else:
-                    combined_dfs[file] = pd.concat(
-                        [combined_dfs[file], df], ignore_index=True)
+                    combined_dfs[file] = pd.concat([combined_dfs[file], df], ignore_index=True)
 
     # Remove duplicates based on specified columns and save
     for file_name, combined_df in combined_dfs.items():
-        combined_df.drop_duplicates(
-
-
-
-
-
+        combined_df.drop_duplicates(
+            subset=[
+                'function',
+                'precision',
+                'x',
+                'y',
+                'z',
+                'px',
+                'py',
+                'backend',
+                'nodes',
+            ],
+            keep='last',
+            inplace=True,
+        )
 
         gpu_output_dir = os.path.join(output_dir, gpu)
         if not os.path.exists(gpu_output_dir):
-            print(f
+            print(f'Creating directory {gpu_output_dir}')
             os.makedirs(gpu_output_dir)
 
         output_file = os.path.join(gpu_output_dir, file_name)
-        print(f
+        print(f'Writing file to {output_file}...')
         combined_df.to_csv(output_file, index=False)
 
 
@@ -287,42 +314,59 @@ def clean_up_csv(
         file_name = os.path.splitext(os.path.basename(csv_file))[0]
         ext = os.path.splitext(os.path.basename(csv_file))[1]
         if ext != '.csv':
-            print(f
+            print(f'Ignoring {csv_file} as it is not a CSV file')
            continue
 
-        df = pd.read_csv(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        df = pd.read_csv(
+            csv_file,
+            header=None,
+            skiprows=0,
+            names=[
+                'function',
+                'precision',
+                'x',
+                'y',
+                'z',
+                'px',
+                'py',
+                'backend',
+                'nodes',
+                'jit_time',
+                'min_time',
+                'max_time',
+                'mean_time',
+                'std_div',
+                'last_time',
+                'generated_code',
+                'argument_size',
+                'output_size',
+                'temp_size',
+                'flops',
+            ],
+            dtype={
+                'function': str,
+                'precision': str,
+                'x': int,
+                'y': int,
+                'z': int,
+                'px': int,
+                'py': int,
+                'backend': str,
+                'nodes': int,
+                'jit_time': float,
+                'min_time': float,
+                'max_time': float,
+                'mean_time': float,
+                'std_div': float,
+                'last_time': float,
+                'generated_code': float,
+                'argument_size': float,
+                'output_size': float,
+                'temp_size': float,
+                'flops': float,
+            },
+            index_col=False,
+        )
 
         # Filter precisions
         if precisions:
@@ -360,18 +404,32 @@ def clean_up_csv(
         df['output_size'] = df['output_size'] / factor
         df['temp_size'] = df['temp_size'] / factor
         # in case of the same test is run multiple times, keep the last one
-        df = df.drop_duplicates(
-
-
-
-
+        df = df.drop_duplicates(
+            subset=[
+                'function',
+                'precision',
+                'x',
+                'y',
+                'z',
+                'px',
+                'py',
+                'backend',
+                'nodes',
+            ],
+            keep='last',
+        )
 
         df['gpus'] = df['px'] * df['py']
 
         if gpus:
             df = df[df['gpus'].isin(gpus)]
 
-        if
+        if (
+            'plot_all' in pdims_strategy
+            or 'slab_yz' in pdims_strategy
+            or 'slab_xy' in pdims_strategy
+            or 'pencils' in pdims_strategy
+        ):
 
            def get_decomp_from_px_py(row):
                if row['px'] > 1 and row['py'] == 1:
@@ -383,7 +441,7 @@ def clean_up_csv(
 
            df['decomp'] = df.apply(get_decomp_from_px_py, axis=1)
            df.drop(columns=['px', 'py'], inplace=True)
-            if
+            if 'plot_all' not in pdims_strategy:
                df = df[df['decomp'].isin(pdims_strategy)]
 
        # check available gpus in dataset
@@ -395,17 +453,18 @@ def clean_up_csv(
        else:
            dataframes[file_name] = pd.concat([dataframes[file_name], df])
 
-    print(f
-    print(
-        f"requested data sizes: {data_sizes} available data sizes: {available_data_sizes}"
-    )
+    print(f'requested GPUS: {gpus} available GPUS: {available_gpu_counts}')
+    print(f'requested data sizes: {data_sizes} available data sizes: {available_data_sizes}')
 
-    available_gpu_counts = (
-
-
-
-
-
+    available_gpu_counts = (
+        available_gpu_counts
+        if gpus is None
+        else [gpu for gpu in gpus if gpu in available_gpu_counts]
+    )
+    available_data_sizes = (
+        available_data_sizes
+        if data_sizes is None
+        else [data_size for data_size in data_sizes if data_size in available_data_sizes]
+    )
 
    return dataframes, available_gpu_counts, available_data_sizes
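Both `concatenate_csvs` and `clean_up_csv` now declare the same header-less, 20-column CSV schema. As orientation, here is a minimal sketch of loading one such row with pandas; the column names and the derived `gpus` column come from the diff above, while the sample row itself is invented:

```python
import io

import pandas as pd

# Column schema as declared in clean_up_csv above.
columns = [
    'function', 'precision', 'x', 'y', 'z', 'px', 'py', 'backend', 'nodes',
    'jit_time', 'min_time', 'max_time', 'mean_time', 'std_div', 'last_time',
    'generated_code', 'argument_size', 'output_size', 'temp_size', 'flops',
]

# One invented sample row: an FFT benchmark on a 256^3 grid, 2x4 decomposition.
sample = io.StringIO(
    'fft,float32,256,256,256,2,4,NCCL,1,'
    '0.9,0.010,0.014,0.012,0.001,0.011,1e6,2e6,2e6,1e6,3e9\n'
)

df = pd.read_csv(sample, header=None, names=columns, index_col=False)
df['gpus'] = df['px'] * df['py']  # same derived column as in clean_up_csv
print(df[['function', 'precision', 'gpus', 'mean_time']])
```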
{jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/METADATA
CHANGED

@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: jax_hpc_profiler
-Version: 0.
-Summary:
+Version: 0.3.0
+Summary: A comprehensive benchmarking and profiling tool designed for JAX in HPC environments, offering automated instrumentation, strong/weak scaling analysis, and performance visualization.
 Author: Wassim Kabalan
 License: GNU GENERAL PUBLIC LICENSE
            Version 3, 29 June 2007
@@ -679,7 +679,7 @@ License: GNU GENERAL PUBLIC LICENSE
 <https://www.gnu.org/licenses/why-not-lgpl.html>.
 
 Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
-Keywords: jax,hpc,
+Keywords: jax,hpc,profiling,benchmarking,visualization,scaling,performance-analysis,gpu,distributed-computing
 Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
@@ -698,10 +698,22 @@ Requires-Dist: pandas
 Requires-Dist: matplotlib
 Requires-Dist: seaborn
 Requires-Dist: tabulate
+Requires-Dist: adjustText
+Requires-Dist: jax>=0.4.0
+Requires-Dist: jaxtyping
+Provides-Extra: test
+Requires-Dist: pytest; extra == "test"
+Requires-Dist: pytest-cov; extra == "test"
 Dynamic: license-file
 
 # JAX HPC Profiler
 
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/python-publish.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/formatting.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/tests.yml)
+[](https://github.com/ASKabalan/jax-hpc-profiler/actions/workflows/notebooks.yml)
+[](https://www.gnu.org/licenses/gpl-3.0)
+
 JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
 
 ## Table of Contents
@@ -883,9 +895,29 @@ jax-hpc-profiler plot -f <csv_files> [options]
 - `-db, --dark_bg`: Use dark background for plotting.
 - `-pd, --print_decompositions`: Print decompositions on plot (experimental).
 - `-b, --backends`: List of backends to include. This argument can be multiple ones.
-- `-sc, --scaling`: Scaling type (`Weak`, `
+- `-sc, --scaling`: Scaling type (`Strong`, `Weak`, `WeakFixed`).
+  - `Strong`: strong scaling with fixed global problem size(s), plotting runtime (or memory) versus number of GPUs.
+  - `Weak`: true weak scaling with explicit `(gpus, data_size)` sequences; requires that `-g/--gpus` and `-d/--data_size` are both provided and have the same length, and plots runtime (or memory) versus number of GPUs on a single figure.
+  - `WeakFixed`: size scaling at fixed GPU count (previous weak behavior); plots runtime (or memory) versus data size, grouped by number of GPUs.
+- `--weak_ideal_line`: When using `-sc Weak`, overlay an ideal flat line based on the smallest-GPU runtime for the first plotted weak-scaling curve.
 - `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
 
+### Weak scaling CLI example
+
+For a weak-scaling run where work per GPU is kept approximately constant, you might provide matching GPU and data-size sequences, for example:
+
+```bash
+jax-hpc-profiler plot \
+    -f MYDATA.csv \
+    -pt mean_time \
+    -sc Weak \
+    -g 1 2 4 8 \
+    -d 32 64 128 256 \
+    --weak_ideal_line
+```
+
+This will produce a single weak-scaling plot of runtime versus number of GPUs, using the points `(gpus, data_size) = (1, 32), (2, 64), (4, 128), (8, 256)` and overlay an ideal weak-scaling reference line.
+
 ## Examples
 
 The repository includes examples for both profiling and plotting.
jax_hpc_profiler-0.3.0.dist-info/RECORD
ADDED

@@ -0,0 +1,12 @@
+jax_hpc_profiler/__init__.py,sha256=GIYY_D3CqhTzEf9Bh8ihar_-ntDYSQpp2utFIgYRbYg,444
+jax_hpc_profiler/create_argparse.py,sha256=ff2e8PvHmbbyF13OH2FTLlpnIUGp9xP8kS--XuJuhZ4,6582
+jax_hpc_profiler/main.py,sha256=ehqU6HwqhjKLs_34tmzWFQU2G-kSiVmhKJ1HIAw-6Lg,3262
+jax_hpc_profiler/plotting.py,sha256=vQsykw4JJNZn6Z6IR5_VABXEHKBhESQdAoAAN4dOaPk,15998
+jax_hpc_profiler/timer.py,sha256=5coHheE6eaviLCZsPuXodbl7pYW9ora-GU9M6PJqRNQ,10442
+jax_hpc_profiler/utils.py,sha256=IfGDbKldJXiDhxb02IxmQV51SFIBYLDUL7Se_OtEOkc,14963
+jax_hpc_profiler-0.3.0.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+jax_hpc_profiler-0.3.0.dist-info/METADATA,sha256=EI-Qb9STk1q_mJC2WMaBGLaqdj8Adtp0FpNeRRx6NNQ,51620
+jax_hpc_profiler-0.3.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+jax_hpc_profiler-0.3.0.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
+jax_hpc_profiler-0.3.0.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
+jax_hpc_profiler-0.3.0.dist-info/RECORD,,
jax_hpc_profiler-0.2.12.dist-info/RECORD
REMOVED

@@ -1,12 +0,0 @@
-jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
-jax_hpc_profiler/create_argparse.py,sha256=6DVpYuj908L01vk09-l7BLU6dW4OGgTepFR123fsawM,6314
-jax_hpc_profiler/main.py,sha256=dwOik2rJw5YV6ocQ-EE32iFOPlq2_3CHHAAuJJFt65Q,2286
-jax_hpc_profiler/plotting.py,sha256=R0mjUhV_Q-qi02mlxWiR241sxr58USBSykwFdjBa-oM,9484
-jax_hpc_profiler/timer.py,sha256=4zc5HlJwepMK633BDz0iLTLWcLsvPdd6M1SL0-qs4js,10554
-jax_hpc_profiler/utils.py,sha256=7i8qPfKogp8nGaGdyJ2-fbQomhIZqn73PQ14qldpFTc,14657
-jax_hpc_profiler-0.2.12.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
-jax_hpc_profiler-0.2.12.dist-info/METADATA,sha256=bIriXXGdfJ7yTIdsuIdZgig3WpshBibBYKtGDjukr8A,49186
-jax_hpc_profiler-0.2.12.dist-info/WHEEL,sha256=tTnHoFhvKQHCh4jz3yCn0WPTYIy7wXx3CJtJ7SJGV7c,91
-jax_hpc_profiler-0.2.12.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
-jax_hpc_profiler-0.2.12.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
-jax_hpc_profiler-0.2.12.dist-info/RECORD,,

{jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/entry_points.txt
File without changes

{jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/licenses/LICENSE
File without changes

{jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/top_level.txt
File without changes