jax-hpc-profiler 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,274 @@
1
+ import os
2
+ import time
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from jax import make_jaxpr
10
+ from jax.experimental import mesh_utils
11
+ from jax.experimental.shard_map import shard_map
12
+ from jax.sharding import Mesh, NamedSharding
13
+ from jax.sharding import PartitionSpec as P
14
+ from tabulate import tabulate
15
+
16
+
17
+ class Timer:
18
+
19
+ def __init__(self, save_jaxpr=False, jax_fn=True, devices=None):
20
+ self.jit_time = 0.0
21
+ self.times = []
22
+ self.profiling_data = {}
23
+ self.compiled_code = {}
24
+ self.save_jaxpr = save_jaxpr
25
+ self.jax_fn = jax_fn
26
+ self.devices = devices
27
+
28
+ def _normalize_memory_units(self, memory_analysis) -> str:
29
+
30
+ if not self.jax_fn:
31
+ return memory_analysis
32
+
33
+ sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
34
+ factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
+ factor = 0 if memory_analysis == 0 else int(
36
+ np.log10(memory_analysis) // 3)
37
+
38
+ return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
39
+
40
+ def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
41
+ if memory_analysis is None:
42
+ return None, None, None, None
43
+ return (
44
+ memory_analysis.generated_code_size_in_bytes,
45
+ memory_analysis.argument_size_in_bytes,
46
+ memory_analysis.output_size_in_bytes,
47
+ memory_analysis.temp_size_in_bytes,
48
+ )
49
+
50
+ def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
51
+ start = time.perf_counter()
52
+ out = fun(*args)
53
+ if self.jax_fn:
54
+ if ndarray_arg is None:
55
+ out.block_until_ready()
56
+ else:
57
+ out[ndarray_arg].block_until_ready()
58
+ end = time.perf_counter()
59
+ self.jit_time = (end - start) * 1e3
60
+
61
+ if self.save_jaxpr:
62
+ jaxpr = make_jaxpr(fun)(*args)
63
+ self.compiled_code["JAXPR"] = jaxpr.pretty_print()
64
+
65
+ if self.jax_fn:
66
+ lowered = jax.jit(fun).lower(*args)
67
+ compiled = lowered.compile()
68
+ memory_analysis = self._read_memory_analysis(
69
+ compiled.memory_analysis())
70
+
71
+ self.compiled_code["LOWERED"] = lowered.as_text()
72
+ self.compiled_code["COMPILED"] = compiled.as_text()
73
+ self.profiling_data["generated_code"] = memory_analysis[0]
74
+ self.profiling_data["argument_size"] = memory_analysis[1]
75
+ self.profiling_data["output_size"] = memory_analysis[2]
76
+ self.profiling_data["temp_size"] = memory_analysis[3]
77
+
78
+ return out
79
+
80
+ def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
81
+ start = time.perf_counter()
82
+ out = fun(*args)
83
+ if self.jax_fn:
84
+ if ndarray_arg is None:
85
+ out.block_until_ready()
86
+ else:
87
+ out[ndarray_arg].block_until_ready()
88
+ end = time.perf_counter()
89
+ self.times.append((end - start) * 1e3)
90
+ return out
91
+
92
+ def _get_mean_times(self) -> np.ndarray:
93
+ if jax.device_count() == 1 or jax.process_count() == 1:
94
+ return np.array(self.times)
95
+
96
+ if self.devices is None:
97
+ self.devices = jax.devices()
98
+
99
+ mesh = jax.make_mesh((len(self.devices), ), ("x", ),
100
+ devices=self.devices)
101
+ sharding = NamedSharding(mesh, P("x"))
102
+
103
+ times_array = jnp.array(self.times)
104
+ global_shape = (jax.device_count(), times_array.shape[0])
105
+ global_times = jax.make_array_from_callback(
106
+ shape=global_shape,
107
+ sharding=sharding,
108
+ data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
109
+ )
110
+
111
+ @partial(shard_map,
112
+ mesh=mesh,
113
+ in_specs=P("x"),
114
+ out_specs=P(),
115
+ check_rep=False)
116
+ def get_mean_times(times):
117
+ return jax.lax.pmean(times, axis_name="x")
118
+
119
+ times_array = get_mean_times(global_times)
120
+ times_array.block_until_ready()
121
+ return np.array(times_array.addressable_data(0)[0])
122
+
123
+ def report(
124
+ self,
125
+ csv_filename: str,
126
+ function: str,
127
+ x: int,
128
+ y: int | None = None,
129
+ z: int | None = None,
130
+ precision: str = "float32",
131
+ px: int = 1,
132
+ py: int = 1,
133
+ backend: str = "NCCL",
134
+ nodes: int = 1,
135
+ md_filename: str | None = None,
136
+ npz_data: Optional[dict] = None,
137
+ extra_info: dict = {},
138
+ ):
139
+ if self.jit_time == 0.0 and len(self.times) == 0:
140
+ print(f"No profiling data to report for {function}")
141
+ return
142
+
143
+ if md_filename is None:
144
+ dirname, filename = (
145
+ os.path.dirname(csv_filename),
146
+ os.path.splitext(os.path.basename(csv_filename))[0],
147
+ )
148
+ report_folder = filename if dirname == "" else f"{dirname}/{filename}"
149
+ os.makedirs(report_folder, exist_ok=True)
150
+ md_filename = (
151
+ f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
152
+ )
153
+
154
+ if npz_data is not None:
155
+ dirname, filename = (
156
+ os.path.dirname(csv_filename),
157
+ os.path.splitext(os.path.basename(csv_filename))[0],
158
+ )
159
+ report_folder = filename if dirname == "" else f"{dirname}/{filename}"
160
+ os.makedirs(report_folder, exist_ok=True)
161
+ npz_filename = (
162
+ f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
163
+ )
164
+ np.savez(npz_filename, **npz_data)
165
+
166
+ y = x if y is None else y
167
+ z = x if z is None else z
168
+
169
+ times_array = self._get_mean_times()
170
+ if jax.process_index() == 0:
171
+
172
+ min_time = np.min(times_array)
173
+ max_time = np.max(times_array)
174
+ mean_time = np.mean(times_array)
175
+ std_time = np.std(times_array)
176
+ last_time = times_array[-1]
177
+
178
+ if self.jax_fn:
179
+
180
+ generated_code = self.profiling_data["generated_code"]
181
+ argument_size = self.profiling_data["argument_size"]
182
+ output_size = self.profiling_data["output_size"]
183
+ temp_size = self.profiling_data["temp_size"]
184
+ else:
185
+ generated_code = "N/A"
186
+ argument_size = "N/A"
187
+ output_size = "N/A"
188
+ temp_size = "N/A"
189
+
190
+ csv_line = (
191
+ f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
192
+ f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
193
+ f"{generated_code},{argument_size},{output_size},{temp_size}\n"
194
+ )
195
+
196
+ with open(csv_filename, "a") as f:
197
+ f.write(csv_line)
198
+
199
+ param_dict = {
200
+ "Function": function,
201
+ "Precision": precision,
202
+ "X": x,
203
+ "Y": y,
204
+ "Z": z,
205
+ "PX": px,
206
+ "PY": py,
207
+ "Backend": backend,
208
+ "Nodes": nodes,
209
+ }
210
+ param_dict.update(extra_info)
211
+ profiling_result = {
212
+ "JIT Time": self.jit_time,
213
+ "Min Time": min_time,
214
+ "Max Time": max_time,
215
+ "Mean Time": mean_time,
216
+ "Std Time": std_time,
217
+ "Last Time": last_time,
218
+ "Generated Code": self._normalize_memory_units(generated_code),
219
+ "Argument Size": self._normalize_memory_units(argument_size),
220
+ "Output Size": self._normalize_memory_units(output_size),
221
+ "Temporary Size": self._normalize_memory_units(temp_size),
222
+ }
223
+ iteration_runs = {}
224
+ for i in range(len(times_array)):
225
+ iteration_runs[f"Run {i}"] = times_array[i]
226
+
227
+ with open(md_filename, "w") as f:
228
+ f.write(f"# Reporting for {function}\n")
229
+ f.write(f"## Parameters\n")
230
+ f.write(
231
+ tabulate(
232
+ param_dict.items(),
233
+ headers=["Parameter", "Value"],
234
+ tablefmt="github",
235
+ ))
236
+ f.write("\n---\n")
237
+ f.write(f"## Profiling Data\n")
238
+ f.write(
239
+ tabulate(
240
+ profiling_result.items(),
241
+ headers=["Parameter", "Value"],
242
+ tablefmt="github",
243
+ ))
244
+ f.write("\n---\n")
245
+ f.write(f"## Iteration Runs\n")
246
+ f.write(
247
+ tabulate(
248
+ iteration_runs.items(),
249
+ headers=["Iteration", "Time"],
250
+ tablefmt="github",
251
+ ))
252
+ if self.jax_fn:
253
+ f.write("\n---\n")
254
+ f.write(f"## Compiled Code\n")
255
+ f.write(f"```hlo\n")
256
+ f.write(self.compiled_code["COMPILED"])
257
+ f.write(f"\n```\n")
258
+ f.write("\n---\n")
259
+ f.write(f"## Lowered Code\n")
260
+ f.write(f"```hlo\n")
261
+ f.write(self.compiled_code["LOWERED"])
262
+ f.write(f"\n```\n")
263
+ f.write("\n---\n")
264
+ if self.save_jaxpr:
265
+ f.write(f"## JAXPR\n")
266
+ f.write(f"```haskel\n")
267
+ f.write(self.compiled_code["JAXPR"])
268
+ f.write(f"\n```\n")
269
+
270
+ # Reset the timer
271
+ self.jit_time = 0.0
272
+ self.times = []
273
+ self.profiling_data = {}
274
+ self.compiled_code = {}
@@ -0,0 +1,411 @@
1
+ import os
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ from matplotlib.axes import Axes
8
+
9
+
10
def inspect_data(dataframes: Dict[str, pd.DataFrame]):
    """
    Print every dataframe of a method-to-dataframe mapping.

    A banner is printed first. Each entry is then shown as its method name,
    the dataframe (via ``inspect_df``) and a closing ``=`` separator line.

    Parameters
    ----------
    dataframes : Dict[str, pd.DataFrame]
        Dictionary of method names to dataframes.
    """
    separator = "=" * 80
    print(separator)
    print("Inspecting dataframes...")
    print(separator)
    for name, frame in dataframes.items():
        print(f"Method: {name}")
        inspect_df(frame)
        print(separator)
26
+
27
+
28
def inspect_df(df: pd.DataFrame):
    """
    Print a dataframe as a Markdown table followed by a dashed rule.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to print; it is rendered with ``to_markdown()``.
    """
    print(df.to_markdown(), "-" * 80, sep="\n")
39
+
40
+
41
# Short placeholder aliases accepted in label templates, mapped to the
# long-form ``%name%`` placeholder that ``expand_label`` later substitutes.
# Insertion order is preserved because expand_label replaces in this order.
params_dict = {
    f"%{alias}%": f"%{long_name}%"
    for alias, long_name in (
        ("pn", "plot_name"),
        ("m", "method_name"),
        ("n", "node"),
        ("b", "backend"),
        ("f", "function"),
        ("cn", "column_name"),
        ("pr", "precision"),
        ("p", "decomposition"),
        ("d", "data_size"),
        ("g", "nb_gpu"),
    )
}
53
+
54
+
55
def expand_label(label_template: str, params: dict) -> str:
    """
    Fill a label template's placeholders with concrete values.

    Short aliases from the module-level ``params_dict`` (e.g. ``%m%``) are
    first rewritten to their long form (``%method_name%``). Every
    ``%key%`` is then replaced by ``str(params[key])``.

    Parameters
    ----------
    label_template : str
        The label template with placeholders.
    params : dict
        Mapping of placeholder names (without ``%``) to their values.

    Returns
    -------
    str
        The expanded label.
    """
    expanded = label_template
    for alias, long_form in params_dict.items():
        expanded = expanded.replace(alias, long_form)

    for name, value in params.items():
        expanded = expanded.replace("%" + name + "%", str(value))
    return expanded
77
+
78
+
79
def plot_with_pdims_strategy(ax: "Axes", df: pd.DataFrame, method: str,
                             pdims_strategy: List[str],
                             print_decompositions: bool, x_col: str,
                             y_col: str, label_template: str):
    """
    Plot ``y_col`` against ``x_col`` for one method according to a pdims strategy.

    Parameters
    ----------
    ax : Axes
        The axes to plot on.
    df : pd.DataFrame
        The data to plot. It must have ``backend``, ``nodes``, ``precision``
        and ``function`` columns. ``'plot_fastest'`` also needs ``px``/``py``,
        and the other strategies need a ``decomp`` column.
    method : str
        The method name (used for the ``%method_name%`` placeholder).
    pdims_strategy : List[str]
        ``'plot_fastest'`` plots, for each ``x_col`` value, the row with the
        smallest ``y_col``. Otherwise, if any of ``'plot_all'``,
        ``'slab_yz'``, ``'slab_xy'`` or ``'pencils'`` is present, one line per
        decomposition is plotted: all of them for ``'plot_all'``, else only
        the listed ones.
    print_decompositions : bool
        With ``'plot_fastest'``, annotate each point with its ``PXxPY``.
    x_col : str
        The column name for the x-axis values.
    y_col : str
        The column name for the y-axis values (also ``%plot_name%``).
    label_template : str
        Template for plot labels with placeholders (see ``expand_label``).

    Returns
    -------
    tuple
        The plotted x values and y values. They are empty lists when ``df``
        is empty or no known strategy is requested.
    """
    # Nothing to plot; this also avoids indexing .values[0] on empty columns.
    if df.empty:
        return [], []

    label_params = {
        "plot_name": y_col,
        "method_name": method,
        "backend": df['backend'].values[0],
        "node": df['nodes'].values[0],
        "precision": df['precision'].values[0],
        "function": df['function'].values[0],
    }

    if 'plot_fastest' in pdims_strategy:
        # groupby yields groups in ascending x order; keep each group's
        # fastest row. Sorting returns new frames, so groups are not mutated.
        fastest_rows = []
        ordered = None
        for _, group in df.groupby(x_col):
            ordered = group.sort_values(by=y_col, kind="stable")
            fastest_rows.append(ordered.iloc[0])
        if not fastest_rows:
            return [], []
        sorted_df = pd.DataFrame(fastest_rows)
        # The label uses the fastest decomposition of the last (largest x) group.
        label_params["decomposition"] = (
            f"{ordered['px'].values[0]}x{ordered['py'].values[0]}")
        label = expand_label(label_template, label_params)
        ax.plot(sorted_df[x_col].values,
                sorted_df[y_col],
                marker='o',
                linestyle='-',
                label=label)
        # TODO(wassim) : this is not working very well
        if print_decompositions:
            for j, (px, py) in enumerate(zip(sorted_df['px'],
                                             sorted_df['py'])):
                ax.annotate(
                    f"{px}x{py}",
                    (sorted_df[x_col].values[j], sorted_df[y_col].values[j]),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha='center',
                    color='red' if j == 0 else 'white')
        return sorted_df[x_col].values, sorted_df[y_col].values

    decomp_strategies = ('plot_all', 'slab_yz', 'slab_xy', 'pencils')
    if any(strategy in pdims_strategy for strategy in decomp_strategies):
        x_values = []
        y_values = []
        for _, group in df.groupby('decomp'):
            decomposition = group['decomp'].values[0]
            # Filter decompositions based on pdims_strategy.
            if ('plot_all' not in pdims_strategy
                    and decomposition not in pdims_strategy):
                continue
            group = (group.drop_duplicates(subset=[x_col, 'decomp'],
                                           keep='last')
                     .sort_values(by=[x_col], ascending=False))

            label_params["decomposition"] = decomposition
            label = expand_label(label_template, label_params)
            ax.plot(group[x_col].values,
                    group[y_col],
                    marker='o',
                    linestyle='-',
                    label=label)
            x_values.extend(group[x_col].values)
            y_values.extend(group[y_col].values)
        return x_values, y_values

    # Unknown strategy: return empty sequences instead of an implicit None
    # so callers can always unpack two values.
    return [], []
177
+
178
+
179
def concatenate_csvs(root_dir: str, output_dir: str):
    """
    Concatenate CSV files per GPU type and remove duplicate runs.

    Each immediate subdirectory of ``root_dir`` is treated as a GPU type.
    All ``.csv`` files found recursively below it are grouped by file name
    and concatenated. They are then de-duplicated on the run parameters,
    keeping the last occurrence in sorted directory/file order, and written
    with a header row to ``output_dir/<gpu>/<file name>``.

    Parameters
    ----------
    root_dir : str
        Root directory containing one subdirectory per GPU type.
    output_dir : str
        Output directory to save concatenated CSV files.
    """
    columns = [
        "function", "precision", "x", "y", "z", "px", "py", "backend",
        "nodes", "jit_time", "min_time", "max_time", "mean_time", "std_div",
        "last_time", "generated_code", "argument_size", "output_size",
        "temp_size", "flops"
    ]
    key_columns = [
        "function", "precision", "x", "y", "z", "px", "py", "backend", "nodes"
    ]

    for gpu in os.listdir(root_dir):
        gpu_dir = os.path.join(root_dir, gpu)

        # Only directories represent GPU types.
        if not os.path.isdir(gpu_dir):
            continue

        # Collect frames per CSV file name and concatenate once, instead of
        # the quadratic cost of repeated pd.concat calls.
        frames_by_name: Dict[str, List[pd.DataFrame]] = {}

        for root, dirs, files in os.walk(gpu_dir):
            # Walk in sorted order so that keep='last' below is deterministic
            # rather than dependent on filesystem listing order.
            dirs.sort()
            for file in sorted(files):
                if not file.endswith('.csv'):
                    continue
                csv_file_path = os.path.join(root, file)
                print(f'Concatenating {csv_file_path}...')
                df = pd.read_csv(csv_file_path,
                                 header=None,
                                 names=columns,
                                 index_col=False)
                frames_by_name.setdefault(file, []).append(df)

        # Remove duplicates based on the run parameters and save.
        for file_name, frames in frames_by_name.items():
            combined_df = pd.concat(frames, ignore_index=True)
            combined_df = combined_df.drop_duplicates(subset=key_columns,
                                                      keep='last')

            gpu_output_dir = os.path.join(output_dir, gpu)
            if not os.path.exists(gpu_output_dir):
                print(f"Creating directory {gpu_output_dir}")
            os.makedirs(gpu_output_dir, exist_ok=True)

            output_file = os.path.join(gpu_output_dir, file_name)
            print(f"Writing file to {output_file}...")
            combined_df.to_csv(output_file, index=False)
241
+
242
+
243
def clean_up_csv(
    csv_files: List[str],
    precisions: Optional[List[str]] = None,
    function_names: Optional[List[str]] = None,
    gpus: Optional[List[int]] = None,
    data_sizes: Optional[List[int]] = None,
    pdims: Optional[List[str]] = None,
    pdims_strategy: Optional[List[str]] = None,
    backends: Optional[List[str]] = None,
    memory_units: str = 'KB',
) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
    """
    Clean up and aggregate data from CSV files.

    The first row of each CSV file is skipped (a header is expected), and
    the remaining rows are read with the fixed profiler column layout. Rows
    are filtered, memory columns are rescaled to ``memory_units``, duplicate
    runs are dropped (keeping the last) and a ``gpus = px * py`` column is
    added. When a decomposition strategy is requested, ``px``/``py`` are
    replaced by a ``decomp`` column (``slab_yz``, ``slab_xy`` or
    ``pencils``).

    Parameters
    ----------
    csv_files : List[str]
        List of CSV files to process; non-``.csv`` paths are skipped.
    precisions : Optional[List[str]], optional
        Precisions to filter by, by default None.
    function_names : Optional[List[str]], optional
        Function names to filter by, by default None.
    gpus : Optional[List[int]], optional
        List of GPU counts (``px * py``) to filter by, by default None.
    data_sizes : Optional[List[int]], optional
        List of data sizes (``x``) to filter by, by default None.
    pdims : Optional[List[str]], optional
        ``"PXxPY"`` decompositions to keep (matched as exact pairs), by
        default None.
    pdims_strategy : Optional[List[str]], optional
        Strategy for plotting pdims, by default ``['plot_fastest']``.
    backends : Optional[List[str]], optional
        List of backends to filter by, by default None (no filtering).
    memory_units : str, optional
        ``'KB'``, ``'MB'``, ``'GB'`` or ``'TB'``; any other value leaves the
        memory columns unscaled. By default ``'KB'``.

    Returns
    -------
    Tuple[Dict[str, pd.DataFrame], List[int], List[int]]
        Dataframes keyed by CSV file stem, then the available GPU counts and
        the available data sizes. When ``gpus`` / ``data_sizes`` were given,
        these are the requested values that are present, in request order;
        otherwise they are all values found, sorted.
    """
    if pdims_strategy is None:
        pdims_strategy = ['plot_fastest']

    columns = [
        "function", "precision", "x", "y", "z", "px", "py", "backend",
        "nodes", "jit_time", "min_time", "max_time", "mean_time", "std_div",
        "last_time", "generated_code", "argument_size", "output_size",
        "temp_size", "flops"
    ]
    dtypes = {
        "function": str,
        "precision": str,
        "x": int,
        "y": int,
        "z": int,
        "px": int,
        "py": int,
        "backend": str,
        "nodes": int,
        "jit_time": float,
        "min_time": float,
        "max_time": float,
        "mean_time": float,
        "std_div": float,
        "last_time": float,
        "generated_code": float,
        "argument_size": float,
        "output_size": float,
        "temp_size": float,
        "flops": float
    }
    memory_columns = [
        "generated_code", "argument_size", "output_size", "temp_size"
    ]
    key_columns = [
        "function", "precision", "x", "y", "z", "px", "py", "backend", "nodes"
    ]
    # Convert memory columns to the requested memory_units.
    factor = {
        'KB': 1024,
        'MB': 1024**2,
        'GB': 1024**3,
        'TB': 1024**4,
    }.get(memory_units, 1)
    decomp_strategies = ('plot_all', 'slab_yz', 'slab_xy', 'pencils')

    # Parse "PXxPY" strings into exact (px, py) pairs. Filtering px and py
    # independently would also keep cross combinations (e.g. 2x2 for 2x4/4x2).
    wanted_pdims = None
    if pdims:
        wanted_pdims = set()
        for pdim in pdims:
            px_str, py_str = pdim.split('x')
            wanted_pdims.add((int(px_str), int(py_str)))

    dataframes = {}
    available_gpu_counts = set()
    available_data_sizes = set()
    for csv_file in csv_files:
        file_name, ext = os.path.splitext(os.path.basename(csv_file))
        if ext != '.csv':
            print(f"Ignoring {csv_file} as it is not a CSV file")
            continue

        df = pd.read_csv(csv_file,
                         header=None,
                         skiprows=1,
                         names=columns,
                         dtype=dtypes,
                         index_col=False)

        if precisions:
            df = df[df['precision'].isin(precisions)]
        if function_names:
            df = df[df['function'].isin(function_names)]
        if backends:
            df = df[df['backend'].isin(backends)]
        if data_sizes:
            df = df[df['x'].isin(data_sizes)]
        if wanted_pdims is not None:
            mask = np.array([
                (px, py) in wanted_pdims
                for px, py in zip(df['px'].tolist(), df['py'].tolist())
            ],
                            dtype=bool)
            df = df[mask]

        # assign() returns new frames, avoiding chained assignment on the
        # filtered slices above.
        df = df.assign(**{col: df[col] / factor for col in memory_columns})
        # In case the same test is run multiple times, keep the last one.
        df = df.drop_duplicates(subset=key_columns, keep='last')
        df = df.assign(gpus=df['px'] * df['py'])

        if gpus:
            df = df[df['gpus'].isin(gpus)]

        if any(strategy in pdims_strategy for strategy in decomp_strategies):
            # Vectorised classification; unlike DataFrame.apply(axis=1) it
            # also works on an empty frame.
            px_values = df['px'].to_numpy()
            py_values = df['py'].to_numpy()
            decomp = np.where(
                (px_values > 1) & (py_values == 1), 'slab_yz',
                np.where((px_values == 1) & (py_values > 1), 'slab_xy',
                         'pencils'))
            df = df.assign(decomp=decomp).drop(columns=['px', 'py'])
            if 'plot_all' not in pdims_strategy:
                df = df[df['decomp'].isin(pdims_strategy)]

        # Record which GPU counts / data sizes exist in the dataset.
        available_gpu_counts.update(df['gpus'].unique())
        available_data_sizes.update(df['x'].unique())

        if dataframes.get(file_name) is None:
            dataframes[file_name] = df
        else:
            dataframes[file_name] = pd.concat([dataframes[file_name], df])

    print(f"requested GPUS: {gpus} available GPUS: {available_gpu_counts}")
    print(
        f"requested data sizes: {data_sizes} available data sizes: {available_data_sizes}"
    )

    # Return lists (as annotated) in a deterministic order.
    if gpus is None:
        gpu_counts = sorted(available_gpu_counts)
    else:
        gpu_counts = [gpu for gpu in gpus if gpu in available_gpu_counts]
    if data_sizes is None:
        sizes = sorted(available_data_sizes)
    else:
        sizes = [
            data_size for data_size in data_sizes
            if data_size in available_data_sizes
        ]

    return dataframes, gpu_counts, sizes