jax-hpc-profiler 0.2.11__tar.gz → 0.2.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19) hide show
  1. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/PKG-INFO +1 -1
  2. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/pyproject.toml +28 -1
  3. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/__init__.py +7 -2
  4. jax_hpc_profiler-0.2.13/src/jax_hpc_profiler/create_argparse.py +199 -0
  5. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/main.py +15 -19
  6. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/plotting.py +58 -66
  7. jax_hpc_profiler-0.2.13/src/jax_hpc_profiler/timer.py +276 -0
  8. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler/utils.py +191 -132
  9. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
  10. jax_hpc_profiler-0.2.11/src/jax_hpc_profiler/create_argparse.py +0 -210
  11. jax_hpc_profiler-0.2.11/src/jax_hpc_profiler/timer.py +0 -289
  12. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/LICENSE +0 -0
  13. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/README.md +0 -0
  14. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/setup.cfg +0 -0
  15. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  16. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  17. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  18. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  19. {jax_hpc_profiler-0.2.11 → jax_hpc_profiler-0.2.13}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.11
3
+ Version: 0.2.13
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.11"
7
+ version = "0.2.13"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -47,3 +47,30 @@ urls = { "Homepage" = "https://github.com/ASKabalan/jax-hpc-profiler" }
47
47
 
48
48
  [project.scripts]
49
49
  jhp = "jax_hpc_profiler.main:main"
50
+
51
+
52
+
53
+ [tool.ruff]
54
+ line-length = 100
55
+ fix = true # autofix issues
56
+ force-exclude = true # useful with ruff-pre-commit plugin
57
+ src = ["src"]
58
+
59
+ [tool.ruff.lint]
60
+ select = [
61
+ 'ARG001', # flake8-unused-function-arguments
62
+ 'E', # pycodestyle-errors
63
+ 'F', # pyflakes
64
+ 'I', # isort
65
+ 'UP', # pyupgrade
66
+ 'T10', # flake8-debugger
67
+ ]
68
+ ignore = [
69
+ 'E203',
70
+ 'E731',
71
+ 'E741',
72
+ 'F722', # conflicts with jaxtyping Array annotations
73
+ ]
74
+
75
+ [tool.ruff.format]
76
+ quote-style = 'single'
@@ -4,6 +4,11 @@ from .timer import Timer
4
4
  from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
5
5
 
6
6
  __all__ = [
7
- 'create_argparser', 'plot_strong_scaling', 'plot_weak_scaling', 'Timer',
8
- 'clean_up_csv', 'concatenate_csvs', 'plot_with_pdims_strategy'
7
+ 'create_argparser',
8
+ 'plot_strong_scaling',
9
+ 'plot_weak_scaling',
10
+ 'Timer',
11
+ 'clean_up_csv',
12
+ 'concatenate_csvs',
13
+ 'plot_with_pdims_strategy',
9
14
  ]
@@ -0,0 +1,199 @@
1
+ import argparse
2
+
3
+
4
+ def create_argparser():
5
+ """
6
+ Create argument parser for the HPC Plotter package.
7
+
8
+ Returns
9
+ -------
10
+ argparse.Namespace
11
+ Parsed and validated arguments.
12
+ """
13
+ parser = argparse.ArgumentParser(description='HPC Plotter for benchmarking data')
14
+
15
+ # Group for concatenation to ensure mutually exclusive behavior
16
+ subparsers = parser.add_subparsers(dest='command', required=True)
17
+
18
+ concat_parser = subparsers.add_parser('concat', help='Concatenate CSV files')
19
+ concat_parser.add_argument('input', type=str, help='Input directory for concatenation')
20
+ concat_parser.add_argument('output', type=str, help='Output directory for concatenation')
21
+
22
+ # Arguments for plotting
23
+ plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
24
+ plot_parser.add_argument(
25
+ '-f', '--csv_files', nargs='+', help='List of CSV files to plot', required=True
26
+ )
27
+ plot_parser.add_argument(
28
+ '-g',
29
+ '--gpus',
30
+ nargs='*',
31
+ type=int,
32
+ help='List of number of GPUs to plot',
33
+ default=None,
34
+ )
35
+ plot_parser.add_argument(
36
+ '-d',
37
+ '--data_size',
38
+ nargs='*',
39
+ type=int,
40
+ help='List of data sizes to plot',
41
+ default=None,
42
+ )
43
+
44
+ # pdims related arguments
45
+ plot_parser.add_argument(
46
+ '-fd',
47
+ '--filter_pdims',
48
+ nargs='*',
49
+ help='List of pdims to filter, e.g., 1x4 2x2 4x8',
50
+ default=None,
51
+ )
52
+ plot_parser.add_argument(
53
+ '-ps',
54
+ '--pdim_strategy',
55
+ choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
56
+ nargs='*',
57
+ default=['plot_fastest'],
58
+ help='Strategy for plotting pdims',
59
+ )
60
+
61
+ # Function and precision related arguments
62
+ plot_parser.add_argument(
63
+ '-pr',
64
+ '--precision',
65
+ choices=['float32', 'float64'],
66
+ default=['float32', 'float64'],
67
+ nargs='*',
68
+ help='Precision to filter by (float32 or float64)',
69
+ )
70
+ plot_parser.add_argument(
71
+ '-fn',
72
+ '--function_name',
73
+ nargs='+',
74
+ help='Function names to filter',
75
+ default=None,
76
+ )
77
+
78
+ # Time or memory related arguments
79
+ plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
80
+ plotting_group.add_argument(
81
+ '-pt',
82
+ '--plot_times',
83
+ nargs='*',
84
+ choices=[
85
+ 'jit_time',
86
+ 'min_time',
87
+ 'max_time',
88
+ 'mean_time',
89
+ 'std_time',
90
+ 'last_time',
91
+ ],
92
+ help='Time columns to plot',
93
+ )
94
+ plotting_group.add_argument(
95
+ '-pm',
96
+ '--plot_memory',
97
+ nargs='*',
98
+ choices=['generated_code', 'argument_size', 'output_size', 'temp_size'],
99
+ help='Memory columns to plot',
100
+ )
101
+ plot_parser.add_argument(
102
+ '-mu',
103
+ '--memory_units',
104
+ default='GB',
105
+ help='Memory units to plot (KB, MB, GB, TB)',
106
+ )
107
+
108
+ # Plot customization arguments
109
+ plot_parser.add_argument(
110
+ '-fs', '--figure_size', nargs=2, type=int, help='Figure size', default=(10, 6)
111
+ )
112
+ plot_parser.add_argument(
113
+ '-o', '--output', help='Output file (if none then only show plot)', default=None
114
+ )
115
+ plot_parser.add_argument(
116
+ '-db', '--dark_bg', action='store_true', help='Use dark background for plotting'
117
+ )
118
+ plot_parser.add_argument(
119
+ '-pd',
120
+ '--print_decompositions',
121
+ action='store_true',
122
+ help='Print decompositions on plot',
123
+ )
124
+
125
+ # Backend related arguments
126
+ plot_parser.add_argument(
127
+ '-b',
128
+ '--backends',
129
+ nargs='*',
130
+ default=['MPI', 'NCCL', 'MPI4JAX'],
131
+ help='List of backends to include',
132
+ )
133
+
134
+ # Scaling type argument
135
+ plot_parser.add_argument(
136
+ '-sc',
137
+ '--scaling',
138
+ choices=['Weak', 'Strong', 'w', 's'],
139
+ required=True,
140
+ help='Scaling type (Weak or Strong)',
141
+ )
142
+
143
+ # Label customization argument
144
+ plot_parser.add_argument(
145
+ '-l',
146
+ '--label_text',
147
+ type=str,
148
+ help=(
149
+ 'Custom label for the plot. You can use placeholders: %%decomposition%% '
150
+ '(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), '
151
+ '%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)'
152
+ ),
153
+ default='%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
154
+ )
155
+
156
+ plot_parser.add_argument(
157
+ '-xl',
158
+ '--xlabel',
159
+ type=str,
160
+ help='X-axis label for the plot',
161
+ )
162
+ plot_parser.add_argument(
163
+ '-tl',
164
+ '--title',
165
+ type=str,
166
+ help='Title for the plot',
167
+ )
168
+
169
+ subparsers.add_parser('label_help', help='Label customization help')
170
+
171
+ args = parser.parse_args()
172
+
173
+ # if command was plot, then check if pdim_strategy is validat
174
+ if args.command == 'plot':
175
+ if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
176
+ print(
177
+ """
178
+ Warning: 'plot_all' strategy is combined with other strategies.
179
+ Using 'plot_all' only.
180
+ """
181
+ )
182
+ args.pdim_strategy = ['plot_all']
183
+
184
+ if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
185
+ print(
186
+ """
187
+ Warning: 'plot_fastest' strategy is combined with other strategies.
188
+ Using 'plot_fastest' only.
189
+ """
190
+ )
191
+ args.pdim_strategy = ['plot_fastest']
192
+ if args.plot_times is not None:
193
+ args.plot_columns = args.plot_times
194
+ elif args.plot_memory is not None:
195
+ args.plot_columns = args.plot_memory
196
+ else:
197
+ raise ValueError('Either plot_times or plot_memory should be provided')
198
+
199
+ return args
@@ -1,29 +1,25 @@
1
- import sys
2
- from typing import List, Optional
3
-
4
1
  from .create_argparse import create_argparser
5
2
  from .plotting import plot_strong_scaling, plot_weak_scaling
6
- from .utils import clean_up_csv, concatenate_csvs
3
+ from .utils import concatenate_csvs
7
4
 
8
5
 
9
6
  def main():
10
7
  args = create_argparser()
11
8
 
12
- if args.command == "concat":
9
+ if args.command == 'concat':
13
10
  input_dir, output_dir = args.input, args.output
14
11
  concatenate_csvs(input_dir, output_dir)
15
- elif args.command == "label_help":
16
- print(f"Customize the label text for the plot. using these commands.")
17
- print(" -- %m% or %methodname%: method name")
18
- print(" -- %f% or %function%: function name")
19
- print(" -- %pn% or %plot_name%: plot name")
20
- print(" -- %pr% or %precision%: precision")
21
- print(" -- %b% or %backend%: backend")
22
- print(" -- %p% or %pdims%: pdims")
23
- print(" -- %n% or %node%: node")
24
- elif args.command == "plot":
25
-
26
- if args.scaling.lower() == "weak" or args.scaling.lower() == "w":
12
+ elif args.command == 'label_help':
13
+ print('Customize the label text for the plot. using these commands.')
14
+ print(' -- %m% or %methodname%: method name')
15
+ print(' -- %f% or %function%: function name')
16
+ print(' -- %pn% or %plot_name%: plot name')
17
+ print(' -- %pr% or %precision%: precision')
18
+ print(' -- %b% or %backend%: backend')
19
+ print(' -- %p% or %pdims%: pdims')
20
+ print(' -- %n% or %node%: node')
21
+ elif args.command == 'plot':
22
+ if args.scaling.lower() == 'weak' or args.scaling.lower() == 'w':
27
23
  plot_weak_scaling(
28
24
  args.csv_files,
29
25
  args.gpus,
@@ -43,7 +39,7 @@ def main():
43
39
  args.dark_bg,
44
40
  args.output,
45
41
  )
46
- elif args.scaling.lower() == "strong" or args.scaling.lower() == "s":
42
+ elif args.scaling.lower() == 'strong' or args.scaling.lower() == 's':
47
43
  plot_strong_scaling(
48
44
  args.csv_files,
49
45
  args.gpus,
@@ -65,5 +61,5 @@ def main():
65
61
  )
66
62
 
67
63
 
68
- if __name__ == "__main__":
64
+ if __name__ == '__main__':
69
65
  main()
@@ -4,23 +4,22 @@ from typing import Dict, List, Optional
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
6
  import pandas as pd
7
- import seaborn as sns
8
7
  from matplotlib.axes import Axes
9
8
  from matplotlib.patches import FancyBboxPatch
10
9
 
11
- from .utils import clean_up_csv, inspect_df, plot_with_pdims_strategy
10
+ from .utils import clean_up_csv, plot_with_pdims_strategy
12
11
 
13
- np.seterr(divide="ignore")
12
+ np.seterr(divide='ignore')
14
13
 
15
14
 
16
15
  def configure_axes(
17
16
  ax: Axes,
18
17
  x_values: List[int],
19
18
  y_values: List[float],
20
- title: str,
19
+ title: Optional[str],
21
20
  xlabel: str,
22
21
  plotting_memory: bool = False,
23
- memory_units: str = "bytes",
22
+ memory_units: str = 'bytes',
24
23
  ):
25
24
  """
26
25
  Configure the axes for the plot.
@@ -36,33 +35,32 @@ def configure_axes(
36
35
  xlabel : str
37
36
  The label for the x-axis.
38
37
  """
39
- ylabel = ("Time (milliseconds)"
40
- if not plotting_memory else f"Memory ({memory_units})")
41
- f2 = lambda x: np.log2(x)
42
- g2 = lambda x: 2**x
38
+ ylabel = 'Time (milliseconds)' if not plotting_memory else f'Memory ({memory_units})'
39
+
40
+ def f2(x):
41
+ return np.log2(x)
42
+
43
+ def g2(x):
44
+ return 2**x
45
+
43
46
  ax.set_xlim([min(x_values), max(x_values)])
44
47
  y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
45
48
  ax.set_title(title)
46
49
  ax.set_ylim([y_min, y_max])
47
- ax.set_xscale("function", functions=(f2, g2))
50
+ ax.set_xscale('function', functions=(f2, g2))
48
51
  if not plotting_memory:
49
- ax.set_yscale("symlog")
52
+ ax.set_yscale('symlog')
50
53
  time_ticks = [
51
- 10**t for t in range(int(np.floor(np.log10(y_min))), 1 +
52
- int(np.ceil(np.log10(y_max))))
54
+ 10**t for t in range(int(np.floor(np.log10(y_min))), 1 + int(np.ceil(np.log10(y_max))))
53
55
  ]
54
56
  ax.set_yticks(time_ticks)
55
57
  ax.set_xticks(x_values)
56
58
  ax.set_xlabel(xlabel)
57
59
  ax.set_ylabel(ylabel)
58
60
  for x_value in x_values:
59
- ax.axvline(x=x_value, color="gray", linestyle="--", alpha=0.5)
61
+ ax.axvline(x=x_value, color='gray', linestyle='--', alpha=0.5)
60
62
  ax.legend(
61
- loc="lower center",
62
- bbox_to_anchor=(0.5, 0.05),
63
- ncol=4,
64
- fontsize="x-large",
65
- prop={"size": 14},
63
+ loc='best',
66
64
  )
67
65
 
68
66
 
@@ -80,10 +78,10 @@ def plot_scaling(
80
78
  backends: Optional[List[str]] = None,
81
79
  precisions: Optional[List[str]] = None,
82
80
  functions: Optional[List[str]] = None,
83
- plot_columns: List[str] = ["mean_time"],
84
- memory_units: str = "bytes",
85
- label_text: str = "plot",
86
- pdims_strategy: List[str] = ["plot_fastest"],
81
+ plot_columns: List[str] = ['mean_time'],
82
+ memory_units: str = 'bytes',
83
+ label_text: str = 'plot',
84
+ pdims_strategy: List[str] = ['plot_fastest'],
87
85
  ):
88
86
  """
89
87
  General scaling plot function based on the number of GPUs or data size.
@@ -115,7 +113,7 @@ def plot_scaling(
115
113
  """
116
114
 
117
115
  if dark_bg:
118
- plt.style.use("dark_background")
116
+ plt.style.use('dark_background')
119
117
 
120
118
  num_subplots = len(fixed_sizes)
121
119
  num_rows = int(np.ceil(np.sqrt(num_subplots)))
@@ -133,28 +131,26 @@ def plot_scaling(
133
131
  x_values = []
134
132
  y_values = []
135
133
  for method, df in dataframes.items():
136
-
137
134
  filtered_method_df = df[df[fixed_column] == int(fixed_size)]
138
135
  if filtered_method_df.empty:
139
136
  continue
140
- filtered_method_df = filtered_method_df.sort_values(
141
- by=[size_column])
142
- functions = (pd.unique(filtered_method_df["function"])
143
- if functions is None else functions)
144
- precisions = (pd.unique(filtered_method_df["precision"])
145
- if precisions is None else precisions)
146
- backends = (pd.unique(filtered_method_df["backend"])
147
- if backends is None else backends)
137
+ filtered_method_df = filtered_method_df.sort_values(by=[size_column])
138
+ functions = (
139
+ pd.unique(filtered_method_df['function']) if functions is None else functions
140
+ )
141
+ precisions = (
142
+ pd.unique(filtered_method_df['precision']) if precisions is None else precisions
143
+ )
144
+ backends = pd.unique(filtered_method_df['backend']) if backends is None else backends
148
145
 
149
- combinations = product(backends, precisions, functions,
150
- plot_columns)
146
+ combinations = product(backends, precisions, functions, plot_columns)
151
147
 
152
148
  for backend, precision, function, plot_column in combinations:
153
-
154
149
  filtered_params_df = filtered_method_df[
155
- (filtered_method_df["backend"] == backend)
156
- & (filtered_method_df["precision"] == precision)
157
- & (filtered_method_df["function"] == function)]
150
+ (filtered_method_df['backend'] == backend)
151
+ & (filtered_method_df['precision'] == precision)
152
+ & (filtered_method_df['function'] == function)
153
+ ]
158
154
  if filtered_params_df.empty:
159
155
  continue
160
156
  x_vals, y_vals = plot_with_pdims_strategy(
@@ -172,12 +168,13 @@ def plot_scaling(
172
168
  y_values.extend(y_vals)
173
169
 
174
170
  if len(x_values) != 0:
175
- plotting_memory = "time" not in plot_columns[0].lower()
171
+ plotting_memory = 'time' not in plot_columns[0].lower()
172
+ figure_title = f'{title} {fixed_size}' if title is not None else None
176
173
  configure_axes(
177
174
  ax,
178
175
  x_values,
179
176
  y_values,
180
- f"{title} {fixed_size}",
177
+ figure_title,
181
178
  xlabel,
182
179
  plotting_memory,
183
180
  memory_units,
@@ -187,17 +184,12 @@ def plot_scaling(
187
184
  fig.delaxes(axs[i])
188
185
 
189
186
  fig.tight_layout()
190
- rect = FancyBboxPatch((0.1, 0.1),
191
- 0.8,
192
- 0.8,
193
- boxstyle="round,pad=0.02",
194
- ec="black",
195
- fc="none")
187
+ rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
196
188
  fig.patches.append(rect)
197
189
  if output is None:
198
190
  plt.show()
199
191
  else:
200
- plt.savefig(output, bbox_inches="tight", transparent=True)
192
+ plt.savefig(output)
201
193
 
202
194
 
203
195
  def plot_strong_scaling(
@@ -207,14 +199,14 @@ def plot_strong_scaling(
207
199
  functions: Optional[List[str]] = None,
208
200
  precisions: Optional[List[str]] = None,
209
201
  pdims: Optional[List[str]] = None,
210
- pdims_strategy: List[str] = ["plot_fastest"],
202
+ pdims_strategy: List[str] = ['plot_fastest'],
211
203
  print_decompositions: bool = False,
212
204
  backends: Optional[List[str]] = None,
213
- plot_columns: List[str] = ["mean_time"],
214
- memory_units: str = "bytes",
215
- label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
216
- xlabel: str = "Number of GPUs",
217
- title: str = "Data sizes",
205
+ plot_columns: List[str] = ['mean_time'],
206
+ memory_units: str = 'bytes',
207
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
208
+ xlabel: str = 'Number of GPUs',
209
+ title: str = 'Data sizes',
218
210
  figure_size: tuple = (6, 4),
219
211
  dark_bg: bool = False,
220
212
  output: Optional[str] = None,
@@ -235,14 +227,14 @@ def plot_strong_scaling(
235
227
  memory_units,
236
228
  )
237
229
  if len(dataframes) == 0:
238
- print(f"No dataframes found for the given arguments. Exiting...")
230
+ print('No dataframes found for the given arguments. Exiting...')
239
231
  return
240
232
 
241
233
  plot_scaling(
242
234
  dataframes,
243
235
  available_data_sizes,
244
- "gpus",
245
- "x",
236
+ 'gpus',
237
+ 'x',
246
238
  xlabel,
247
239
  title,
248
240
  figure_size,
@@ -266,14 +258,14 @@ def plot_weak_scaling(
266
258
  functions: Optional[List[str]] = None,
267
259
  precisions: Optional[List[str]] = None,
268
260
  pdims: Optional[List[str]] = None,
269
- pdims_strategy: List[str] = ["plot_fastest"],
261
+ pdims_strategy: List[str] = ['plot_fastest'],
270
262
  print_decompositions: bool = False,
271
263
  backends: Optional[List[str]] = None,
272
- plot_columns: List[str] = ["mean_time"],
273
- memory_units: str = "bytes",
274
- label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
275
- xlabel: str = "Data sizes",
276
- title: str = "Number of GPUs",
264
+ plot_columns: List[str] = ['mean_time'],
265
+ memory_units: str = 'bytes',
266
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
267
+ xlabel: str = 'Data sizes',
268
+ title: str = 'Number of GPUs',
277
269
  figure_size: tuple = (6, 4),
278
270
  dark_bg: bool = False,
279
271
  output: Optional[str] = None,
@@ -293,14 +285,14 @@ def plot_weak_scaling(
293
285
  memory_units,
294
286
  )
295
287
  if len(dataframes) == 0:
296
- print(f"No dataframes found for the given arguments. Exiting...")
288
+ print('No dataframes found for the given arguments. Exiting...')
297
289
  return
298
290
 
299
291
  plot_scaling(
300
292
  dataframes,
301
293
  available_gpu_counts,
302
- "x",
303
- "gpus",
294
+ 'x',
295
+ 'gpus',
304
296
  xlabel,
305
297
  title,
306
298
  figure_size,