jax-hpc-profiler 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/create_argparse.py +2 -1
- jax_hpc_profiler/main.py +13 -5
- jax_hpc_profiler/plotting.py +13 -10
- jax_hpc_profiler/timer.py +2 -10
- jax_hpc_profiler/utils.py +3 -3
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.8.dist-info}/METADATA +1 -1
- jax_hpc_profiler-0.2.8.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.8.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.6.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.8.dist-info}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.8.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -67,7 +67,6 @@ def create_argparser():
|
|
|
67
67
|
plot_parser.add_argument('-fn',
|
|
68
68
|
'--function_name',
|
|
69
69
|
nargs='+',
|
|
70
|
-
default=['FFT'],
|
|
71
70
|
help='Function names to filter')
|
|
72
71
|
|
|
73
72
|
# Time or memory related arguments
|
|
@@ -135,6 +134,8 @@ def create_argparser():
|
|
|
135
134
|
'Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%)',
|
|
136
135
|
default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
|
|
137
136
|
|
|
137
|
+
subparsers.add_parser('label_help',help='Label customization help')
|
|
138
|
+
|
|
138
139
|
args = parser.parse_args()
|
|
139
140
|
|
|
140
141
|
# if command was plot, then check if pdim_strategy is validat
|
jax_hpc_profiler/main.py
CHANGED
|
@@ -11,6 +11,15 @@ def main():
|
|
|
11
11
|
if args.command == 'concat':
|
|
12
12
|
input_dir, output_dir = args.input, args.output
|
|
13
13
|
concatenate_csvs(input_dir, output_dir)
|
|
14
|
+
elif args.command == 'label_help':
|
|
15
|
+
print(f"Customize the label text for the plot. using these commands.")
|
|
16
|
+
print(' -- %m% or %methodname%: method name')
|
|
17
|
+
print(' -- %f% or %function%: function name')
|
|
18
|
+
print(' -- %pn% or %plot_name%: plot name')
|
|
19
|
+
print(' -- %pr% or %precision%: precision')
|
|
20
|
+
print(' -- %b% or %backend%: backend')
|
|
21
|
+
print(' -- %p% or %pdims%: pdims')
|
|
22
|
+
print(' -- %n% or %node%: node')
|
|
14
23
|
elif args.command == 'plot':
|
|
15
24
|
dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
|
|
16
25
|
args.csv_files, args.precision, args.function_name, args.gpus,
|
|
@@ -23,11 +32,10 @@ def main():
|
|
|
23
32
|
f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
|
|
24
33
|
)
|
|
25
34
|
# filter back the requested data sizes and gpus
|
|
26
|
-
|
|
27
|
-
args.
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
]
|
|
35
|
+
|
|
36
|
+
args.gpus = available_gpu_counts if args.gpus is None else [gpu for gpu in args.gpus if gpu in available_gpu_counts]
|
|
37
|
+
args.data_size = available_data_sizes if args.data_size is None else [data_size for data_size in args.data_size if data_size in available_data_sizes]
|
|
38
|
+
|
|
31
39
|
if len(args.gpus) == 0:
|
|
32
40
|
print(f"No dataframes found for the given GPUs. Exiting...")
|
|
33
41
|
sys.exit(1)
|
jax_hpc_profiler/plotting.py
CHANGED
|
@@ -7,10 +7,10 @@ import pandas as pd
|
|
|
7
7
|
from matplotlib.axes import Axes
|
|
8
8
|
from matplotlib.patches import FancyBboxPatch
|
|
9
9
|
|
|
10
|
-
from .utils import inspect_df, plot_with_pdims_strategy
|
|
10
|
+
from .utils import inspect_df, plot_with_pdims_strategy, inspect_data
|
|
11
11
|
|
|
12
12
|
np.seterr(divide='ignore')
|
|
13
|
-
plt.rcParams.update({'font.size':
|
|
13
|
+
plt.rcParams.update({'font.size': 15})
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def configure_axes(ax: Axes,
|
|
@@ -38,7 +38,7 @@ def configure_axes(ax: Axes,
|
|
|
38
38
|
f2 = lambda x: np.log2(x)
|
|
39
39
|
g2 = lambda x: 2**x
|
|
40
40
|
ax.set_xlim([min(x_values), max(x_values)])
|
|
41
|
-
y_min, y_max = min(y_values) * 0.
|
|
41
|
+
y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
|
|
42
42
|
ax.set_title(title)
|
|
43
43
|
ax.set_ylim([y_min, y_max])
|
|
44
44
|
ax.set_xscale('function', functions=(f2, g2))
|
|
@@ -57,6 +57,7 @@ def configure_axes(ax: Axes,
|
|
|
57
57
|
ax.legend(loc='lower center',
|
|
58
58
|
bbox_to_anchor=(0.5, 0.05),
|
|
59
59
|
ncol=4,
|
|
60
|
+
fontsize="x-large",
|
|
60
61
|
prop={'size': 14})
|
|
61
62
|
|
|
62
63
|
|
|
@@ -117,13 +118,13 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
|
|
|
117
118
|
axs = axs.flatten()
|
|
118
119
|
else:
|
|
119
120
|
axs = [axs]
|
|
120
|
-
|
|
121
|
+
|
|
121
122
|
for i, fixed_size in enumerate(fixed_sizes):
|
|
122
123
|
ax: Axes = axs[i]
|
|
123
124
|
|
|
125
|
+
x_values = []
|
|
126
|
+
y_values = []
|
|
124
127
|
for method, df in dataframes.items():
|
|
125
|
-
x_values = []
|
|
126
|
-
y_values = []
|
|
127
128
|
|
|
128
129
|
filtered_method_df = df[df[fixed_column] == int(fixed_size)]
|
|
129
130
|
if filtered_method_df.empty:
|
|
@@ -136,6 +137,7 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
|
|
|
136
137
|
plot_columns)
|
|
137
138
|
|
|
138
139
|
for backend, precision, function, plot_column in combinations:
|
|
140
|
+
|
|
139
141
|
filtered_params_df = filtered_method_df[
|
|
140
142
|
(filtered_method_df['backend'] == backend)
|
|
141
143
|
& (filtered_method_df['precision'] == precision) &
|
|
@@ -148,10 +150,11 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
|
|
|
148
150
|
|
|
149
151
|
x_values.extend(x_vals)
|
|
150
152
|
y_values.extend(y_vals)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
153
|
+
|
|
154
|
+
if len(x_values) != 0:
|
|
155
|
+
plotting_memory = 'time' not in plot_columns[0].lower()
|
|
156
|
+
configure_axes(ax, x_values, y_values, f"{title} {fixed_size}", xlabel,
|
|
157
|
+
plotting_memory, memory_units)
|
|
155
158
|
|
|
156
159
|
for i in range(num_subplots, num_rows * num_cols):
|
|
157
160
|
fig.delaxes(axs[i])
|
jax_hpc_profiler/timer.py
CHANGED
|
@@ -23,16 +23,11 @@ class Timer:
|
|
|
23
23
|
self.compiled_code = {}
|
|
24
24
|
self.save_jaxpr = save_jaxpr
|
|
25
25
|
|
|
26
|
-
def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
|
|
27
|
-
if cost_analysis is None:
|
|
28
|
-
return None
|
|
29
|
-
return cost_analysis[0]['flops']
|
|
30
|
-
|
|
31
26
|
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
27
|
|
|
33
28
|
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
34
29
|
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
|
-
factor = int(np.log10(memory_analysis) // 3)
|
|
30
|
+
factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
|
|
36
31
|
|
|
37
32
|
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
38
33
|
|
|
@@ -62,11 +57,9 @@ class Timer:
|
|
|
62
57
|
compiled = lowered.compile()
|
|
63
58
|
memory_analysis = self._read_memory_analysis(
|
|
64
59
|
compiled.memory_analysis())
|
|
65
|
-
cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
|
|
66
60
|
|
|
67
61
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
68
62
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
69
|
-
self.profiling_data["FLOPS"] = cost_analysis
|
|
70
63
|
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
71
64
|
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
72
65
|
self.profiling_data["output_size"] = memory_analysis[2]
|
|
@@ -145,7 +138,6 @@ class Timer:
|
|
|
145
138
|
std_time = np.std(times_array)
|
|
146
139
|
last_time = times_array[-1]
|
|
147
140
|
|
|
148
|
-
flops = self.profiling_data["FLOPS"]
|
|
149
141
|
generated_code = self.profiling_data["generated_code"]
|
|
150
142
|
argument_size = self.profiling_data["argument_size"]
|
|
151
143
|
output_size = self.profiling_data["output_size"]
|
|
@@ -154,7 +146,7 @@ class Timer:
|
|
|
154
146
|
csv_line = (
|
|
155
147
|
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
156
148
|
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
157
|
-
f"{generated_code},{argument_size},{output_size},{temp_size}
|
|
149
|
+
f"{generated_code},{argument_size},{output_size},{temp_size}\n"
|
|
158
150
|
)
|
|
159
151
|
|
|
160
152
|
with open(csv_filename, 'a') as f:
|
jax_hpc_profiler/utils.py
CHANGED
|
@@ -123,7 +123,7 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
|
|
|
123
123
|
# Sort all and keep fastest
|
|
124
124
|
sorted_dfs = []
|
|
125
125
|
for _, group in df_decomp:
|
|
126
|
-
group.sort_values(by=[y_col], inplace=True, ascending=
|
|
126
|
+
group.sort_values(by=[y_col], inplace=True, ascending=True)
|
|
127
127
|
sorted_dfs.append(group.iloc[0])
|
|
128
128
|
sorted_df = pd.DataFrame(sorted_dfs)
|
|
129
129
|
label_params.update({
|
|
@@ -336,12 +336,11 @@ def clean_up_csv(
|
|
|
336
336
|
# Filter data sizes
|
|
337
337
|
if data_sizes:
|
|
338
338
|
df = df[df['x'].isin(data_sizes)]
|
|
339
|
-
|
|
339
|
+
|
|
340
340
|
# Filter pdims
|
|
341
341
|
if pdims:
|
|
342
342
|
px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
|
|
343
343
|
df = df[(df['px'].isin(px_list)) & (df['py'].isin(py_list))]
|
|
344
|
-
|
|
345
344
|
# convert memory units columns to remquested memory_units
|
|
346
345
|
match memory_units:
|
|
347
346
|
case 'KB':
|
|
@@ -385,6 +384,7 @@ def clean_up_csv(
|
|
|
385
384
|
df.drop(columns=['px', 'py'], inplace=True)
|
|
386
385
|
if not 'plot_all' in pdims_strategy:
|
|
387
386
|
df = df[df['decomp'].isin(pdims_strategy)]
|
|
387
|
+
|
|
388
388
|
# check available gpus in dataset
|
|
389
389
|
available_gpu_counts.update(df['gpus'].unique())
|
|
390
390
|
available_data_sizes.update(df['x'].unique())
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
+
jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
|
|
3
|
+
jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
|
|
4
|
+
jax_hpc_profiler/plotting.py,sha256=jnWpyfZOB5w5L4BNIU5euKbUt__wfGPyCjUITrEwScM,8431
|
|
5
|
+
jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
|
|
6
|
+
jax_hpc_profiler/utils.py,sha256=tUXnNHwQSSCqA6XxLd4MoV2gyFIC7ncB-uvVc2INhms,14140
|
|
7
|
+
jax_hpc_profiler-0.2.8.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
+
jax_hpc_profiler-0.2.8.dist-info/METADATA,sha256=NDOKx7RzKr4Z4mJnY457-4uxaP2TDhwWYH5rZRxfukQ,49250
|
|
9
|
+
jax_hpc_profiler-0.2.8.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
|
|
10
|
+
jax_hpc_profiler-0.2.8.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
+
jax_hpc_profiler-0.2.8.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
+
jax_hpc_profiler-0.2.8.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
-
jax_hpc_profiler/create_argparse.py,sha256=dEicamRYqJ6GGdgcph2bwAbmdxPkS4tS12xZ4c0X_Pk,6484
|
|
3
|
-
jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
|
|
4
|
-
jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
|
|
5
|
-
jax_hpc_profiler/timer.py,sha256=j6oH5IZz12VJik2cE7EQ3a9tAW9C8xl7D2QLW8Bkz3s,8617
|
|
6
|
-
jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
|
|
7
|
-
jax_hpc_profiler-0.2.6.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
-
jax_hpc_profiler-0.2.6.dist-info/METADATA,sha256=AgLXyb89gdxgeyDv02_P5oqIvoffdTh1mc3zFPUVBAU,49250
|
|
9
|
-
jax_hpc_profiler-0.2.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
10
|
-
jax_hpc_profiler-0.2.6.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
-
jax_hpc_profiler-0.2.6.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
-
jax_hpc_profiler-0.2.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|