jax-hpc-profiler 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/__init__.py +9 -0
- jax_hpc_profiler/create_argparse.py +197 -0
- jax_hpc_profiler/main.py +65 -0
- jax_hpc_profiler/plotting.py +313 -0
- jax_hpc_profiler/timer.py +274 -0
- jax_hpc_profiler/utils.py +411 -0
- jax_hpc_profiler-0.2.10.dist-info/LICENSE +674 -0
- jax_hpc_profiler-0.2.10.dist-info/METADATA +902 -0
- jax_hpc_profiler-0.2.10.dist-info/RECORD +12 -0
- jax_hpc_profiler-0.2.10.dist-info/WHEEL +5 -0
- jax_hpc_profiler-0.2.10.dist-info/entry_points.txt +2 -0
- jax_hpc_profiler-0.2.10.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import Any, Callable, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import numpy as np
|
|
9
|
+
from jax import make_jaxpr
|
|
10
|
+
from jax.experimental import mesh_utils
|
|
11
|
+
from jax.experimental.shard_map import shard_map
|
|
12
|
+
from jax.sharding import Mesh, NamedSharding
|
|
13
|
+
from jax.sharding import PartitionSpec as P
|
|
14
|
+
from tabulate import tabulate
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Timer:
|
|
18
|
+
|
|
19
|
+
def __init__(self, save_jaxpr=False, jax_fn=True, devices=None):
|
|
20
|
+
self.jit_time = 0.0
|
|
21
|
+
self.times = []
|
|
22
|
+
self.profiling_data = {}
|
|
23
|
+
self.compiled_code = {}
|
|
24
|
+
self.save_jaxpr = save_jaxpr
|
|
25
|
+
self.jax_fn = jax_fn
|
|
26
|
+
self.devices = devices
|
|
27
|
+
|
|
28
|
+
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
29
|
+
|
|
30
|
+
if not self.jax_fn:
|
|
31
|
+
return memory_analysis
|
|
32
|
+
|
|
33
|
+
sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
|
|
34
|
+
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
|
+
factor = 0 if memory_analysis == 0 else int(
|
|
36
|
+
np.log10(memory_analysis) // 3)
|
|
37
|
+
|
|
38
|
+
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
39
|
+
|
|
40
|
+
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
41
|
+
if memory_analysis is None:
|
|
42
|
+
return None, None, None, None
|
|
43
|
+
return (
|
|
44
|
+
memory_analysis.generated_code_size_in_bytes,
|
|
45
|
+
memory_analysis.argument_size_in_bytes,
|
|
46
|
+
memory_analysis.output_size_in_bytes,
|
|
47
|
+
memory_analysis.temp_size_in_bytes,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
51
|
+
start = time.perf_counter()
|
|
52
|
+
out = fun(*args)
|
|
53
|
+
if self.jax_fn:
|
|
54
|
+
if ndarray_arg is None:
|
|
55
|
+
out.block_until_ready()
|
|
56
|
+
else:
|
|
57
|
+
out[ndarray_arg].block_until_ready()
|
|
58
|
+
end = time.perf_counter()
|
|
59
|
+
self.jit_time = (end - start) * 1e3
|
|
60
|
+
|
|
61
|
+
if self.save_jaxpr:
|
|
62
|
+
jaxpr = make_jaxpr(fun)(*args)
|
|
63
|
+
self.compiled_code["JAXPR"] = jaxpr.pretty_print()
|
|
64
|
+
|
|
65
|
+
if self.jax_fn:
|
|
66
|
+
lowered = jax.jit(fun).lower(*args)
|
|
67
|
+
compiled = lowered.compile()
|
|
68
|
+
memory_analysis = self._read_memory_analysis(
|
|
69
|
+
compiled.memory_analysis())
|
|
70
|
+
|
|
71
|
+
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
72
|
+
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
73
|
+
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
74
|
+
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
75
|
+
self.profiling_data["output_size"] = memory_analysis[2]
|
|
76
|
+
self.profiling_data["temp_size"] = memory_analysis[3]
|
|
77
|
+
|
|
78
|
+
return out
|
|
79
|
+
|
|
80
|
+
def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
81
|
+
start = time.perf_counter()
|
|
82
|
+
out = fun(*args)
|
|
83
|
+
if self.jax_fn:
|
|
84
|
+
if ndarray_arg is None:
|
|
85
|
+
out.block_until_ready()
|
|
86
|
+
else:
|
|
87
|
+
out[ndarray_arg].block_until_ready()
|
|
88
|
+
end = time.perf_counter()
|
|
89
|
+
self.times.append((end - start) * 1e3)
|
|
90
|
+
return out
|
|
91
|
+
|
|
92
|
+
def _get_mean_times(self) -> np.ndarray:
|
|
93
|
+
if jax.device_count() == 1 or jax.process_count() == 1:
|
|
94
|
+
return np.array(self.times)
|
|
95
|
+
|
|
96
|
+
if self.devices is None:
|
|
97
|
+
self.devices = jax.devices()
|
|
98
|
+
|
|
99
|
+
mesh = jax.make_mesh((len(self.devices), ), ("x", ),
|
|
100
|
+
devices=self.devices)
|
|
101
|
+
sharding = NamedSharding(mesh, P("x"))
|
|
102
|
+
|
|
103
|
+
times_array = jnp.array(self.times)
|
|
104
|
+
global_shape = (jax.device_count(), times_array.shape[0])
|
|
105
|
+
global_times = jax.make_array_from_callback(
|
|
106
|
+
shape=global_shape,
|
|
107
|
+
sharding=sharding,
|
|
108
|
+
data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
@partial(shard_map,
|
|
112
|
+
mesh=mesh,
|
|
113
|
+
in_specs=P("x"),
|
|
114
|
+
out_specs=P(),
|
|
115
|
+
check_rep=False)
|
|
116
|
+
def get_mean_times(times):
|
|
117
|
+
return jax.lax.pmean(times, axis_name="x")
|
|
118
|
+
|
|
119
|
+
times_array = get_mean_times(global_times)
|
|
120
|
+
times_array.block_until_ready()
|
|
121
|
+
return np.array(times_array.addressable_data(0)[0])
|
|
122
|
+
|
|
123
|
+
def report(
|
|
124
|
+
self,
|
|
125
|
+
csv_filename: str,
|
|
126
|
+
function: str,
|
|
127
|
+
x: int,
|
|
128
|
+
y: int | None = None,
|
|
129
|
+
z: int | None = None,
|
|
130
|
+
precision: str = "float32",
|
|
131
|
+
px: int = 1,
|
|
132
|
+
py: int = 1,
|
|
133
|
+
backend: str = "NCCL",
|
|
134
|
+
nodes: int = 1,
|
|
135
|
+
md_filename: str | None = None,
|
|
136
|
+
npz_data: Optional[dict] = None,
|
|
137
|
+
extra_info: dict = {},
|
|
138
|
+
):
|
|
139
|
+
if self.jit_time == 0.0 and len(self.times) == 0:
|
|
140
|
+
print(f"No profiling data to report for {function}")
|
|
141
|
+
return
|
|
142
|
+
|
|
143
|
+
if md_filename is None:
|
|
144
|
+
dirname, filename = (
|
|
145
|
+
os.path.dirname(csv_filename),
|
|
146
|
+
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
147
|
+
)
|
|
148
|
+
report_folder = filename if dirname == "" else f"{dirname}/{filename}"
|
|
149
|
+
os.makedirs(report_folder, exist_ok=True)
|
|
150
|
+
md_filename = (
|
|
151
|
+
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
if npz_data is not None:
|
|
155
|
+
dirname, filename = (
|
|
156
|
+
os.path.dirname(csv_filename),
|
|
157
|
+
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
158
|
+
)
|
|
159
|
+
report_folder = filename if dirname == "" else f"{dirname}/{filename}"
|
|
160
|
+
os.makedirs(report_folder, exist_ok=True)
|
|
161
|
+
npz_filename = (
|
|
162
|
+
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
|
|
163
|
+
)
|
|
164
|
+
np.savez(npz_filename, **npz_data)
|
|
165
|
+
|
|
166
|
+
y = x if y is None else y
|
|
167
|
+
z = x if z is None else z
|
|
168
|
+
|
|
169
|
+
times_array = self._get_mean_times()
|
|
170
|
+
if jax.process_index() == 0:
|
|
171
|
+
|
|
172
|
+
min_time = np.min(times_array)
|
|
173
|
+
max_time = np.max(times_array)
|
|
174
|
+
mean_time = np.mean(times_array)
|
|
175
|
+
std_time = np.std(times_array)
|
|
176
|
+
last_time = times_array[-1]
|
|
177
|
+
|
|
178
|
+
if self.jax_fn:
|
|
179
|
+
|
|
180
|
+
generated_code = self.profiling_data["generated_code"]
|
|
181
|
+
argument_size = self.profiling_data["argument_size"]
|
|
182
|
+
output_size = self.profiling_data["output_size"]
|
|
183
|
+
temp_size = self.profiling_data["temp_size"]
|
|
184
|
+
else:
|
|
185
|
+
generated_code = "N/A"
|
|
186
|
+
argument_size = "N/A"
|
|
187
|
+
output_size = "N/A"
|
|
188
|
+
temp_size = "N/A"
|
|
189
|
+
|
|
190
|
+
csv_line = (
|
|
191
|
+
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
192
|
+
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
193
|
+
f"{generated_code},{argument_size},{output_size},{temp_size}\n"
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
with open(csv_filename, "a") as f:
|
|
197
|
+
f.write(csv_line)
|
|
198
|
+
|
|
199
|
+
param_dict = {
|
|
200
|
+
"Function": function,
|
|
201
|
+
"Precision": precision,
|
|
202
|
+
"X": x,
|
|
203
|
+
"Y": y,
|
|
204
|
+
"Z": z,
|
|
205
|
+
"PX": px,
|
|
206
|
+
"PY": py,
|
|
207
|
+
"Backend": backend,
|
|
208
|
+
"Nodes": nodes,
|
|
209
|
+
}
|
|
210
|
+
param_dict.update(extra_info)
|
|
211
|
+
profiling_result = {
|
|
212
|
+
"JIT Time": self.jit_time,
|
|
213
|
+
"Min Time": min_time,
|
|
214
|
+
"Max Time": max_time,
|
|
215
|
+
"Mean Time": mean_time,
|
|
216
|
+
"Std Time": std_time,
|
|
217
|
+
"Last Time": last_time,
|
|
218
|
+
"Generated Code": self._normalize_memory_units(generated_code),
|
|
219
|
+
"Argument Size": self._normalize_memory_units(argument_size),
|
|
220
|
+
"Output Size": self._normalize_memory_units(output_size),
|
|
221
|
+
"Temporary Size": self._normalize_memory_units(temp_size),
|
|
222
|
+
}
|
|
223
|
+
iteration_runs = {}
|
|
224
|
+
for i in range(len(times_array)):
|
|
225
|
+
iteration_runs[f"Run {i}"] = times_array[i]
|
|
226
|
+
|
|
227
|
+
with open(md_filename, "w") as f:
|
|
228
|
+
f.write(f"# Reporting for {function}\n")
|
|
229
|
+
f.write(f"## Parameters\n")
|
|
230
|
+
f.write(
|
|
231
|
+
tabulate(
|
|
232
|
+
param_dict.items(),
|
|
233
|
+
headers=["Parameter", "Value"],
|
|
234
|
+
tablefmt="github",
|
|
235
|
+
))
|
|
236
|
+
f.write("\n---\n")
|
|
237
|
+
f.write(f"## Profiling Data\n")
|
|
238
|
+
f.write(
|
|
239
|
+
tabulate(
|
|
240
|
+
profiling_result.items(),
|
|
241
|
+
headers=["Parameter", "Value"],
|
|
242
|
+
tablefmt="github",
|
|
243
|
+
))
|
|
244
|
+
f.write("\n---\n")
|
|
245
|
+
f.write(f"## Iteration Runs\n")
|
|
246
|
+
f.write(
|
|
247
|
+
tabulate(
|
|
248
|
+
iteration_runs.items(),
|
|
249
|
+
headers=["Iteration", "Time"],
|
|
250
|
+
tablefmt="github",
|
|
251
|
+
))
|
|
252
|
+
if self.jax_fn:
|
|
253
|
+
f.write("\n---\n")
|
|
254
|
+
f.write(f"## Compiled Code\n")
|
|
255
|
+
f.write(f"```hlo\n")
|
|
256
|
+
f.write(self.compiled_code["COMPILED"])
|
|
257
|
+
f.write(f"\n```\n")
|
|
258
|
+
f.write("\n---\n")
|
|
259
|
+
f.write(f"## Lowered Code\n")
|
|
260
|
+
f.write(f"```hlo\n")
|
|
261
|
+
f.write(self.compiled_code["LOWERED"])
|
|
262
|
+
f.write(f"\n```\n")
|
|
263
|
+
f.write("\n---\n")
|
|
264
|
+
if self.save_jaxpr:
|
|
265
|
+
f.write(f"## JAXPR\n")
|
|
266
|
+
f.write(f"```haskel\n")
|
|
267
|
+
f.write(self.compiled_code["JAXPR"])
|
|
268
|
+
f.write(f"\n```\n")
|
|
269
|
+
|
|
270
|
+
# Reset the timer
|
|
271
|
+
self.jit_time = 0.0
|
|
272
|
+
self.times = []
|
|
273
|
+
self.profiling_data = {}
|
|
274
|
+
self.compiled_code = {}
|
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Dict, List, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from matplotlib.axes import Axes
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def inspect_data(dataframes: Dict[str, pd.DataFrame]):
    """
    Print every dataframe of the mapping, framed by banner lines.

    Parameters
    ----------
    dataframes : Dict[str, pd.DataFrame]
        Dictionary of method names to dataframes.
    """
    banner = "=" * 80
    print(banner)
    print("Inspecting dataframes...")
    print(banner)
    for method_name in dataframes:
        print(f"Method: {method_name}")
        inspect_df(dataframes[method_name])
        print(banner)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def inspect_df(df: pd.DataFrame):
    """
    Print the dataframe as a Markdown table followed by a dashed rule.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to inspect.
    """
    rendered_table = df.to_markdown()
    rule = "-" * 80
    print(rendered_table)
    print(rule)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Short placeholder -> long placeholder.  Short forms are rewritten to the
# long forms first, so a template may use either spelling.
params_dict = {
    "%pn%": "%plot_name%",
    "%m%": "%method_name%",
    "%n%": "%node%",
    "%b%": "%backend%",
    "%f%": "%function%",
    "%cn%": "%column_name%",
    "%pr%": "%precision%",
    "%p%": "%decomposition%",
    "%d%": "%data_size%",
    "%g%": "%nb_gpu%",
}


def expand_label(label_template: str, params: dict) -> str:
    """
    Fill the placeholders of a label template.

    Short placeholders (e.g. ``%m%``) are first rewritten to their long
    form (``%method_name%``); each ``%<key>%`` is then replaced by the
    string form of ``params[key]``.  Placeholders without a matching key
    are left in the result.

    Parameters
    ----------
    label_template : str
        The label template with placeholders.
    params : dict
        Values keyed by long placeholder name (without the ``%`` signs).

    Returns
    -------
    str
        The expanded label.
    """
    expanded = label_template

    for short_form in params_dict:
        expanded = expanded.replace(short_form, params_dict[short_form])

    for name in params:
        placeholder = f"%{name}%"
        expanded = expanded.replace(placeholder, str(params[name]))

    return expanded
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
                             pdims_strategy: List[str],
                             print_decompositions: bool, x_col: str,
                             y_col: str, label_template: str):
    """
    Plot the data based on the pdims strategy.

    Parameters
    ----------
    ax : Axes
        The axes to plot on.
    df : pd.DataFrame
        The dataframe to plot.  Its first row supplies the backend, node
        count, precision and function name used in the label, so it must
        not be empty.
    method : str
        The method name.
    pdims_strategy : List[str]
        Strategy for plotting pdims.  'plot_fastest' draws one line made of
        the smallest ``y_col`` value for each ``x_col`` value; 'plot_all',
        'slab_yz', 'slab_xy' and 'pencils' draw one line per value of the
        'decomp' column ('plot_all' keeps every decomposition, the others
        keep only the named ones).
    print_decompositions : bool
        Whether to annotate each point with its ``px`` x ``py``
        decomposition (used by 'plot_fastest' only).
    x_col : str
        The column name for the x-axis values.
    y_col : str
        The column name for the y-axis values.
    label_template : str
        Template for plot labels with placeholders.

    Returns
    -------
    tuple or None
        The plotted ``(x_values, y_values)``, or ``None`` when
        ``pdims_strategy`` names none of the known strategies.
    """
    label_params = {
        "plot_name": y_col,
        "method_name": method,
        "backend": df['backend'].values[0],
        "node": df['nodes'].values[0],
        "precision": df['precision'].values[0],
        "function": df['function'].values[0],
    }

    if 'plot_fastest' in pdims_strategy:
        # Keep the fastest row of every x value.  Each group is sorted
        # into a new frame: groupby yields sub-frames of `df`, and sorting
        # those in place is what raises pandas' copy warnings.
        fastest_rows = []
        last_sorted_group = None
        for _, group in df.groupby([x_col]):
            last_sorted_group = group.sort_values(by=[y_col], ascending=True)
            fastest_rows.append(last_sorted_group.iloc[0])
        sorted_df = pd.DataFrame(fastest_rows)

        # The label shows the fastest decomposition of the last x group.
        label_params.update({
            "decomposition":
            f"{last_sorted_group['px'].values[0]}x{last_sorted_group['py'].values[0]}"
        })
        label = expand_label(label_template, label_params)
        ax.plot(sorted_df[x_col].values,
                sorted_df[y_col],
                marker='o',
                linestyle='-',
                label=label)
        # TODO(wassim) : this is not working very well
        if print_decompositions:
            for j, (px, py) in enumerate(zip(sorted_df['px'],
                                             sorted_df['py'])):
                ax.annotate(
                    f"{px}x{py}",
                    (sorted_df[x_col].values[j], sorted_df[y_col].values[j]),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha='center',
                    color='red' if j == 0 else 'white')
        return sorted_df[x_col].values, sorted_df[y_col].values

    if any(strategy in pdims_strategy
           for strategy in ['plot_all', 'slab_yz', 'slab_xy', 'pencils']):
        plot_everything = 'plot_all' in pdims_strategy
        x_values = []
        y_values = []
        for _, group in df.groupby(['decomp']):
            # Keep the last measurement of each x value, largest x first.
            group = group.drop_duplicates(subset=[x_col, 'decomp'],
                                          keep='last')
            group = group.sort_values(by=[x_col], ascending=False)
            decomp = group['decomp'].values[0]
            # filter decomp based on pdims_strategy
            if not plot_everything and decomp not in pdims_strategy:
                continue

            label_params.update({"decomposition": decomp})
            label = expand_label(label_template, label_params)
            ax.plot(group[x_col].values,
                    group[y_col],
                    marker='o',
                    linestyle='-',
                    label=label)
            x_values.extend(group[x_col].values)
            y_values.extend(group[y_col].values)
        return x_values, y_values

    # No known strategy was requested: nothing is plotted.
    return None
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def concatenate_csvs(root_dir: str, output_dir: str):
    """
    Concatenate CSV files and remove duplicates by GPU type.

    Every sub-directory of ``root_dir`` is treated as one GPU type.  All
    ``.csv`` files found below it (recursively) that share a file name are
    concatenated, de-duplicated on the run description columns (the last
    occurrence wins) and written to ``output_dir/<gpu>/<file name>`` with a
    header row.

    Directories and files are visited in sorted order so that "last
    occurrence" does not depend on the order in which the file system
    lists entries.

    Parameters
    ----------
    root_dir : str
        Root directory containing CSV files.
    output_dir : str
        Output directory to save concatenated CSV files.
    """
    column_names = [
        "function", "precision", "x", "y", "z", "px", "py", "backend",
        "nodes", "jit_time", "min_time", "max_time", "mean_time", "std_div",
        "last_time", "generated_code", "argument_size", "output_size",
        "temp_size", "flops"
    ]
    # Columns that identify one benchmark configuration.
    key_columns = [
        "function", "precision", "x", "y", "z", "px", "py", "backend", "nodes"
    ]

    # Iterate over each GPU type directory
    for gpu in sorted(os.listdir(root_dir)):
        gpu_dir = os.path.join(root_dir, gpu)

        # Skip plain files sitting next to the GPU directories
        if not os.path.isdir(gpu_dir):
            continue

        # File name -> frames read from every file carrying that name
        frames_by_name = {}

        # List CSV in directory and subdirectories
        for root, dirs, files in os.walk(gpu_dir):
            dirs.sort()  # in-place sort steers os.walk's descent order
            for file in sorted(files):
                if not file.endswith('.csv'):
                    continue
                csv_file_path = os.path.join(root, file)
                print(f'Concatenating {csv_file_path}...')
                df = pd.read_csv(csv_file_path,
                                 header=None,
                                 names=column_names,
                                 index_col=False)
                frames_by_name.setdefault(file, []).append(df)

        # Remove duplicates based on specified columns and save
        for file_name, frames in frames_by_name.items():
            # One concat per file name instead of one per file read.
            combined_df = pd.concat(frames, ignore_index=True)
            combined_df = combined_df.drop_duplicates(subset=key_columns,
                                                      keep='last')

            gpu_output_dir = os.path.join(output_dir, gpu)
            if not os.path.exists(gpu_output_dir):
                print(f"Creating directory {gpu_output_dir}")
            # exist_ok avoids failing if the directory appears meanwhile.
            os.makedirs(gpu_output_dir, exist_ok=True)

            output_file = os.path.join(gpu_output_dir, file_name)
            print(f"Writing file to {output_file}...")
            combined_df.to_csv(output_file, index=False)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def clean_up_csv(
    csv_files: List[str],
    precisions: Optional[List[str]] = None,
    function_names: Optional[List[str]] = None,
    gpus: Optional[List[int]] = None,
    data_sizes: Optional[List[int]] = None,
    pdims: Optional[List[str]] = None,
    pdims_strategy: Optional[List[str]] = None,
    backends: Optional[List[str]] = None,
    memory_units: str = 'KB',
) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
    """
    Clean up and aggregate data from CSV files.

    The first line of every file is skipped as a header row.  Files that
    share a base name are concatenated under that name.

    Parameters
    ----------
    csv_files : List[str]
        List of CSV files to process; files without a ``.csv`` extension
        are ignored.
    precisions : Optional[List[str]], optional
        Precisions to filter by, by default None (no filtering).
    function_names : Optional[List[str]], optional
        Function names to filter by, by default None (no filtering).
    gpus : Optional[List[int]], optional
        List of GPU counts (``px * py``) to filter by, by default None.
    data_sizes : Optional[List[int]], optional
        List of data sizes (``x`` column) to filter by, by default None.
    pdims : Optional[List[str]], optional
        List of ``"<px>x<py>"`` decompositions to keep, by default None.
        A row is kept only when its (px, py) pair is one of the listed
        pairs.
    pdims_strategy : List[str], optional
        Strategy for plotting pdims, by default ['plot_fastest'].  With
        'plot_all', 'slab_yz', 'slab_xy' or 'pencils' a 'decomp' column
        replaces the 'px' and 'py' columns.
    backends : Optional[List[str]], optional
        List of backends to filter by, by default None (no filtering).
    memory_units : str, optional
        Unit the memory columns are converted to ('KB', 'MB', 'GB' or
        'TB'); any other value leaves them in bytes.  By default 'KB'.

    Returns
    -------
    Dict[str, pd.DataFrame]
        Dictionary of method names (CSV base names) to cleaned dataframes.
    List[int]
        GPU counts present in the data: sorted when ``gpus`` is None,
        otherwise the requested counts that are present, in request order.
    List[int]
        Data sizes present in the data, built the same way from
        ``data_sizes``.
    """
    if pdims_strategy is None:
        pdims_strategy = ['plot_fastest']

    column_names = [
        "function", "precision", "x", "y", "z", "px", "py", "backend",
        "nodes", "jit_time", "min_time", "max_time", "mean_time", "std_div",
        "last_time", "generated_code", "argument_size", "output_size",
        "temp_size", "flops"
    ]
    column_types = {
        "function": str,
        "precision": str,
        "x": int,
        "y": int,
        "z": int,
        "px": int,
        "py": int,
        "backend": str,
        "nodes": int,
        "jit_time": float,
        "min_time": float,
        "max_time": float,
        "mean_time": float,
        "std_div": float,
        "last_time": float,
        "generated_code": float,
        "argument_size": float,
        "output_size": float,
        "temp_size": float,
        "flops": float
    }
    key_columns = [
        "function", "precision", "x", "y", "z", "px", "py", "backend", "nodes"
    ]
    memory_columns = [
        'generated_code', 'argument_size', 'output_size', 'temp_size'
    ]
    decomp_strategies = ('plot_all', 'slab_yz', 'slab_xy', 'pencils')
    # Bytes per requested unit; unknown units keep the values in bytes.
    factor = {
        'KB': 1024,
        'MB': 1024**2,
        'GB': 1024**3,
        'TB': 1024**4
    }.get(memory_units, 1)

    dataframes = {}
    available_gpu_counts = set()
    available_data_sizes = set()
    for csv_file in csv_files:
        file_name, ext = os.path.splitext(os.path.basename(csv_file))
        if ext != '.csv':
            print(f"Ignoring {csv_file} as it is not a CSV file")
            continue

        df = pd.read_csv(csv_file,
                         header=None,
                         skiprows=1,
                         names=column_names,
                         dtype=column_types,
                         index_col=False)

        # Filter precisions
        if precisions:
            df = df[df['precision'].isin(precisions)]
        # Filter function names
        if function_names:
            df = df[df['function'].isin(function_names)]
        # Filter backends
        if backends:
            df = df[df['backend'].isin(backends)]

        # Filter data sizes
        if data_sizes:
            df = df[df['x'].isin(data_sizes)]

        # Filter pdims on (px, py) pairs.  Testing px and py independently
        # would also keep unrequested combinations such as 1x1 and 4x4
        # when "1x4" and "4x1" are requested.
        if pdims:
            wanted_pairs = {
                tuple(int(part) for part in pdim.split('x'))
                for pdim in pdims
            }
            keep = np.fromiter((pair in wanted_pairs
                                for pair in zip(df['px'], df['py'])),
                               dtype=bool,
                               count=len(df))
            df = df[keep]

        # Work on a private copy: the filters above may have returned
        # slices of the frame that was read.
        df = df.copy()

        # convert memory units columns to requested memory_units
        for column in memory_columns:
            df[column] = df[column] / factor

        # in case of the same test is run multiple times, keep the last one
        df = df.drop_duplicates(subset=key_columns, keep='last')

        df = df.assign(gpus=df['px'] * df['py'])

        if gpus:
            df = df[df['gpus'].isin(gpus)]

        if any(strategy in pdims_strategy for strategy in decomp_strategies):
            px_values = df['px'].to_numpy()
            py_values = df['py'].to_numpy()
            # Vectorised classification; unlike a row-wise apply it also
            # works when the filters left no rows.
            decomp = np.select(
                [(px_values > 1) & (py_values == 1),
                 (px_values == 1) & (py_values > 1)],
                ['slab_yz', 'slab_xy'],
                default='pencils',
            )
            df = df.assign(decomp=decomp).drop(columns=['px', 'py'])
            if 'plot_all' not in pdims_strategy:
                df = df[df['decomp'].isin(pdims_strategy)]

        # check available gpus in dataset
        available_gpu_counts.update(df['gpus'].unique())
        available_data_sizes.update(df['x'].unique())

        if dataframes.get(file_name) is None:
            dataframes[file_name] = df
        else:
            dataframes[file_name] = pd.concat([dataframes[file_name], df])

    print(f"requested GPUS: {gpus} available GPUS: {available_gpu_counts}")
    print(
        f"requested data sizes: {data_sizes} available data sizes: {available_data_sizes}"
    )

    # Always return lists, as the signature announces.
    if gpus is None:
        gpu_counts = sorted(available_gpu_counts)
    else:
        gpu_counts = [gpu for gpu in gpus if gpu in available_gpu_counts]
    if data_sizes is None:
        sizes = sorted(available_data_sizes)
    else:
        sizes = [
            data_size for data_size in data_sizes
            if data_size in available_data_sizes
        ]

    return dataframes, gpu_counts, sizes
|