jax-hpc-profiler 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,11 @@ from .timer import Timer
4
4
  from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
5
5
 
6
6
  __all__ = [
7
- 'create_argparser', 'plot_strong_scaling', 'plot_weak_scaling', 'Timer',
8
- 'clean_up_csv', 'concatenate_csvs', 'plot_with_pdims_strategy'
7
+ 'create_argparser',
8
+ 'plot_strong_scaling',
9
+ 'plot_weak_scaling',
10
+ 'Timer',
11
+ 'clean_up_csv',
12
+ 'concatenate_csvs',
13
+ 'plot_with_pdims_strategy',
9
14
  ]
@@ -10,201 +10,190 @@ def create_argparser():
10
10
  argparse.Namespace
11
11
  Parsed and validated arguments.
12
12
  """
13
- parser = argparse.ArgumentParser(
14
- description="HPC Plotter for benchmarking data")
13
+ parser = argparse.ArgumentParser(description='HPC Plotter for benchmarking data')
15
14
 
16
15
  # Group for concatenation to ensure mutually exclusive behavior
17
- subparsers = parser.add_subparsers(dest="command", required=True)
16
+ subparsers = parser.add_subparsers(dest='command', required=True)
18
17
 
19
- concat_parser = subparsers.add_parser("concat",
20
- help="Concatenate CSV files")
21
- concat_parser.add_argument("input",
22
- type=str,
23
- help="Input directory for concatenation")
24
- concat_parser.add_argument("output",
25
- type=str,
26
- help="Output directory for concatenation")
18
+ concat_parser = subparsers.add_parser('concat', help='Concatenate CSV files')
19
+ concat_parser.add_argument('input', type=str, help='Input directory for concatenation')
20
+ concat_parser.add_argument('output', type=str, help='Output directory for concatenation')
27
21
 
28
22
  # Arguments for plotting
29
- plot_parser = subparsers.add_parser("plot", help="Plot CSV data")
30
- plot_parser.add_argument("-f",
31
- "--csv_files",
32
- nargs="+",
33
- help="List of CSV files to plot",
34
- required=True)
35
- plot_parser.add_argument(
36
- "-g",
37
- "--gpus",
38
- nargs="*",
23
+ plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
24
+ plot_parser.add_argument(
25
+ '-f', '--csv_files', nargs='+', help='List of CSV files to plot', required=True
26
+ )
27
+ plot_parser.add_argument(
28
+ '-g',
29
+ '--gpus',
30
+ nargs='*',
39
31
  type=int,
40
- help="List of number of GPUs to plot",
32
+ help='List of number of GPUs to plot',
41
33
  default=None,
42
34
  )
43
35
  plot_parser.add_argument(
44
- "-d",
45
- "--data_size",
46
- nargs="*",
36
+ '-d',
37
+ '--data_size',
38
+ nargs='*',
47
39
  type=int,
48
- help="List of data sizes to plot",
40
+ help='List of data sizes to plot',
49
41
  default=None,
50
42
  )
51
43
 
52
44
  # pdims related arguments
53
45
  plot_parser.add_argument(
54
- "-fd",
55
- "--filter_pdims",
56
- nargs="*",
57
- help="List of pdims to filter, e.g., 1x4 2x2 4x8",
46
+ '-fd',
47
+ '--filter_pdims',
48
+ nargs='*',
49
+ help='List of pdims to filter, e.g., 1x4 2x2 4x8',
58
50
  default=None,
59
51
  )
60
52
  plot_parser.add_argument(
61
- "-ps",
62
- "--pdim_strategy",
63
- choices=["plot_all", "plot_fastest", "slab_yz", "slab_xy", "pencils"],
64
- nargs="*",
65
- default=["plot_fastest"],
66
- help="Strategy for plotting pdims",
53
+ '-ps',
54
+ '--pdim_strategy',
55
+ choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
56
+ nargs='*',
57
+ default=['plot_fastest'],
58
+ help='Strategy for plotting pdims',
67
59
  )
68
60
 
69
61
  # Function and precision related arguments
70
62
  plot_parser.add_argument(
71
- "-pr",
72
- "--precision",
73
- choices=["float32", "float64"],
74
- default=["float32", "float64"],
75
- nargs="*",
76
- help="Precision to filter by (float32 or float64)",
63
+ '-pr',
64
+ '--precision',
65
+ choices=['float32', 'float64'],
66
+ default=['float32', 'float64'],
67
+ nargs='*',
68
+ help='Precision to filter by (float32 or float64)',
77
69
  )
78
70
  plot_parser.add_argument(
79
- "-fn",
80
- "--function_name",
81
- nargs="+",
82
- help="Function names to filter",
71
+ '-fn',
72
+ '--function_name',
73
+ nargs='+',
74
+ help='Function names to filter',
83
75
  default=None,
84
76
  )
85
77
 
86
78
  # Time or memory related arguments
87
79
  plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
88
80
  plotting_group.add_argument(
89
- "-pt",
90
- "--plot_times",
91
- nargs="*",
81
+ '-pt',
82
+ '--plot_times',
83
+ nargs='*',
92
84
  choices=[
93
- "jit_time",
94
- "min_time",
95
- "max_time",
96
- "mean_time",
97
- "std_time",
98
- "last_time",
85
+ 'jit_time',
86
+ 'min_time',
87
+ 'max_time',
88
+ 'mean_time',
89
+ 'std_time',
90
+ 'last_time',
99
91
  ],
100
- help="Time columns to plot",
92
+ help='Time columns to plot',
101
93
  )
102
94
  plotting_group.add_argument(
103
- "-pm",
104
- "--plot_memory",
105
- nargs="*",
106
- choices=[
107
- "generated_code", "argument_size", "output_size", "temp_size"
108
- ],
109
- help="Memory columns to plot",
95
+ '-pm',
96
+ '--plot_memory',
97
+ nargs='*',
98
+ choices=['generated_code', 'argument_size', 'output_size', 'temp_size'],
99
+ help='Memory columns to plot',
110
100
  )
111
101
  plot_parser.add_argument(
112
- "-mu",
113
- "--memory_units",
114
- default="GB",
115
- help="Memory units to plot (KB, MB, GB, TB)",
102
+ '-mu',
103
+ '--memory_units',
104
+ default='GB',
105
+ help='Memory units to plot (KB, MB, GB, TB)',
116
106
  )
117
107
 
118
108
  # Plot customization arguments
119
- plot_parser.add_argument("-fs",
120
- "--figure_size",
121
- nargs=2,
122
- type=int,
123
- help="Figure size",
124
- default=(10, 6))
125
- plot_parser.add_argument("-o",
126
- "--output",
127
- help="Output file (if none then only show plot)",
128
- default=None)
129
- plot_parser.add_argument("-db",
130
- "--dark_bg",
131
- action="store_true",
132
- help="Use dark background for plotting")
133
- plot_parser.add_argument(
134
- "-pd",
135
- "--print_decompositions",
136
- action="store_true",
137
- help="Print decompositions on plot",
109
+ plot_parser.add_argument(
110
+ '-fs', '--figure_size', nargs=2, type=int, help='Figure size', default=(10, 6)
111
+ )
112
+ plot_parser.add_argument(
113
+ '-o', '--output', help='Output file (if none then only show plot)', default=None
114
+ )
115
+ plot_parser.add_argument(
116
+ '-db', '--dark_bg', action='store_true', help='Use dark background for plotting'
117
+ )
118
+ plot_parser.add_argument(
119
+ '-pd',
120
+ '--print_decompositions',
121
+ action='store_true',
122
+ help='Print decompositions on plot',
138
123
  )
139
124
 
140
125
  # Backend related arguments
141
126
  plot_parser.add_argument(
142
- "-b",
143
- "--backends",
144
- nargs="*",
145
- default=["MPI", "NCCL", "MPI4JAX"],
146
- help="List of backends to include",
127
+ '-b',
128
+ '--backends',
129
+ nargs='*',
130
+ default=['MPI', 'NCCL', 'MPI4JAX'],
131
+ help='List of backends to include',
147
132
  )
148
133
 
149
134
  # Scaling type argument
150
135
  plot_parser.add_argument(
151
- "-sc",
152
- "--scaling",
153
- choices=["Weak", "Strong", "w", "s"],
136
+ '-sc',
137
+ '--scaling',
138
+ choices=['Weak', 'Strong', 'w', 's'],
154
139
  required=True,
155
- help="Scaling type (Weak or Strong)",
140
+ help='Scaling type (Weak or Strong)',
156
141
  )
157
142
 
158
143
  # Label customization argument
159
144
  plot_parser.add_argument(
160
- "-l",
161
- "--label_text",
145
+ '-l',
146
+ '--label_text',
162
147
  type=str,
163
- help=
164
- ("Custom label for the plot. You can use placeholders: %%decomposition%% "
165
- "(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), "
166
- "%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)"
167
- ),
168
- default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
148
+ help=(
149
+ 'Custom label for the plot. You can use placeholders: %%decomposition%% '
150
+ '(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), '
151
+ '%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)'
152
+ ),
153
+ default='%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
169
154
  )
170
155
 
171
156
  plot_parser.add_argument(
172
- "-xl",
173
- "--xlabel",
157
+ '-xl',
158
+ '--xlabel',
174
159
  type=str,
175
- help="X-axis label for the plot",
160
+ help='X-axis label for the plot',
176
161
  )
177
162
  plot_parser.add_argument(
178
- "-tl",
179
- "--title",
163
+ '-tl',
164
+ '--title',
180
165
  type=str,
181
- help="Title for the plot",
166
+ help='Title for the plot',
182
167
  )
183
168
 
184
- subparsers.add_parser("label_help", help="Label customization help")
169
+ subparsers.add_parser('label_help', help='Label customization help')
185
170
 
186
171
  args = parser.parse_args()
187
172
 
188
173
  # if command was plot, then check if pdim_strategy is validat
189
- if args.command == "plot":
190
- if "plot_all" in args.pdim_strategy and len(args.pdim_strategy) > 1:
174
+ if args.command == 'plot':
175
+ if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
191
176
  print(
192
- "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
177
+ """
178
+ Warning: 'plot_all' strategy is combined with other strategies.
179
+ Using 'plot_all' only.
180
+ """
193
181
  )
194
- args.pdim_strategy = ["plot_all"]
182
+ args.pdim_strategy = ['plot_all']
195
183
 
196
- if "plot_fastest" in args.pdim_strategy and len(
197
- args.pdim_strategy) > 1:
184
+ if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
198
185
  print(
199
- "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
186
+ """
187
+ Warning: 'plot_fastest' strategy is combined with other strategies.
188
+ Using 'plot_fastest' only.
189
+ """
200
190
  )
201
- args.pdim_strategy = ["plot_fastest"]
191
+ args.pdim_strategy = ['plot_fastest']
202
192
  if args.plot_times is not None:
203
193
  args.plot_columns = args.plot_times
204
194
  elif args.plot_memory is not None:
205
195
  args.plot_columns = args.plot_memory
206
196
  else:
207
- raise ValueError(
208
- "Either plot_times or plot_memory should be provided")
197
+ raise ValueError('Either plot_times or plot_memory should be provided')
209
198
 
210
199
  return args
jax_hpc_profiler/main.py CHANGED
@@ -1,29 +1,25 @@
1
- import sys
2
- from typing import List, Optional
3
-
4
1
  from .create_argparse import create_argparser
5
2
  from .plotting import plot_strong_scaling, plot_weak_scaling
6
- from .utils import clean_up_csv, concatenate_csvs
3
+ from .utils import concatenate_csvs
7
4
 
8
5
 
9
6
  def main():
10
7
  args = create_argparser()
11
8
 
12
- if args.command == "concat":
9
+ if args.command == 'concat':
13
10
  input_dir, output_dir = args.input, args.output
14
11
  concatenate_csvs(input_dir, output_dir)
15
- elif args.command == "label_help":
16
- print(f"Customize the label text for the plot. using these commands.")
17
- print(" -- %m% or %methodname%: method name")
18
- print(" -- %f% or %function%: function name")
19
- print(" -- %pn% or %plot_name%: plot name")
20
- print(" -- %pr% or %precision%: precision")
21
- print(" -- %b% or %backend%: backend")
22
- print(" -- %p% or %pdims%: pdims")
23
- print(" -- %n% or %node%: node")
24
- elif args.command == "plot":
25
-
26
- if args.scaling.lower() == "weak" or args.scaling.lower() == "w":
12
+ elif args.command == 'label_help':
13
+ print('Customize the label text for the plot. using these commands.')
14
+ print(' -- %m% or %methodname%: method name')
15
+ print(' -- %f% or %function%: function name')
16
+ print(' -- %pn% or %plot_name%: plot name')
17
+ print(' -- %pr% or %precision%: precision')
18
+ print(' -- %b% or %backend%: backend')
19
+ print(' -- %p% or %pdims%: pdims')
20
+ print(' -- %n% or %node%: node')
21
+ elif args.command == 'plot':
22
+ if args.scaling.lower() == 'weak' or args.scaling.lower() == 'w':
27
23
  plot_weak_scaling(
28
24
  args.csv_files,
29
25
  args.gpus,
@@ -43,7 +39,7 @@ def main():
43
39
  args.dark_bg,
44
40
  args.output,
45
41
  )
46
- elif args.scaling.lower() == "strong" or args.scaling.lower() == "s":
42
+ elif args.scaling.lower() == 'strong' or args.scaling.lower() == 's':
47
43
  plot_strong_scaling(
48
44
  args.csv_files,
49
45
  args.gpus,
@@ -65,5 +61,5 @@ def main():
65
61
  )
66
62
 
67
63
 
68
- if __name__ == "__main__":
64
+ if __name__ == '__main__':
69
65
  main()
@@ -4,23 +4,22 @@ from typing import Dict, List, Optional
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
6
  import pandas as pd
7
- import seaborn as sns
8
7
  from matplotlib.axes import Axes
9
8
  from matplotlib.patches import FancyBboxPatch
10
9
 
11
- from .utils import clean_up_csv, inspect_df, plot_with_pdims_strategy
10
+ from .utils import clean_up_csv, plot_with_pdims_strategy
12
11
 
13
- np.seterr(divide="ignore")
12
+ np.seterr(divide='ignore')
14
13
 
15
14
 
16
15
  def configure_axes(
17
16
  ax: Axes,
18
17
  x_values: List[int],
19
18
  y_values: List[float],
20
- title: str,
19
+ title: Optional[str],
21
20
  xlabel: str,
22
21
  plotting_memory: bool = False,
23
- memory_units: str = "bytes",
22
+ memory_units: str = 'bytes',
24
23
  ):
25
24
  """
26
25
  Configure the axes for the plot.
@@ -36,33 +35,32 @@ def configure_axes(
36
35
  xlabel : str
37
36
  The label for the x-axis.
38
37
  """
39
- ylabel = ("Time (milliseconds)"
40
- if not plotting_memory else f"Memory ({memory_units})")
41
- f2 = lambda x: np.log2(x)
42
- g2 = lambda x: 2**x
38
+ ylabel = 'Time (milliseconds)' if not plotting_memory else f'Memory ({memory_units})'
39
+
40
+ def f2(x):
41
+ return np.log2(x)
42
+
43
+ def g2(x):
44
+ return 2**x
45
+
43
46
  ax.set_xlim([min(x_values), max(x_values)])
44
47
  y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
45
48
  ax.set_title(title)
46
49
  ax.set_ylim([y_min, y_max])
47
- ax.set_xscale("function", functions=(f2, g2))
50
+ ax.set_xscale('function', functions=(f2, g2))
48
51
  if not plotting_memory:
49
- ax.set_yscale("symlog")
52
+ ax.set_yscale('symlog')
50
53
  time_ticks = [
51
- 10**t for t in range(int(np.floor(np.log10(y_min))), 1 +
52
- int(np.ceil(np.log10(y_max))))
54
+ 10**t for t in range(int(np.floor(np.log10(y_min))), 1 + int(np.ceil(np.log10(y_max))))
53
55
  ]
54
56
  ax.set_yticks(time_ticks)
55
57
  ax.set_xticks(x_values)
56
58
  ax.set_xlabel(xlabel)
57
59
  ax.set_ylabel(ylabel)
58
60
  for x_value in x_values:
59
- ax.axvline(x=x_value, color="gray", linestyle="--", alpha=0.5)
61
+ ax.axvline(x=x_value, color='gray', linestyle='--', alpha=0.5)
60
62
  ax.legend(
61
- loc="lower center",
62
- bbox_to_anchor=(0.5, 0.05),
63
- ncol=4,
64
- fontsize="x-large",
65
- prop={"size": 14},
63
+ loc='best',
66
64
  )
67
65
 
68
66
 
@@ -80,10 +78,10 @@ def plot_scaling(
80
78
  backends: Optional[List[str]] = None,
81
79
  precisions: Optional[List[str]] = None,
82
80
  functions: Optional[List[str]] = None,
83
- plot_columns: List[str] = ["mean_time"],
84
- memory_units: str = "bytes",
85
- label_text: str = "plot",
86
- pdims_strategy: List[str] = ["plot_fastest"],
81
+ plot_columns: List[str] = ['mean_time'],
82
+ memory_units: str = 'bytes',
83
+ label_text: str = 'plot',
84
+ pdims_strategy: List[str] = ['plot_fastest'],
87
85
  ):
88
86
  """
89
87
  General scaling plot function based on the number of GPUs or data size.
@@ -115,7 +113,7 @@ def plot_scaling(
115
113
  """
116
114
 
117
115
  if dark_bg:
118
- plt.style.use("dark_background")
116
+ plt.style.use('dark_background')
119
117
 
120
118
  num_subplots = len(fixed_sizes)
121
119
  num_rows = int(np.ceil(np.sqrt(num_subplots)))
@@ -133,28 +131,26 @@ def plot_scaling(
133
131
  x_values = []
134
132
  y_values = []
135
133
  for method, df in dataframes.items():
136
-
137
134
  filtered_method_df = df[df[fixed_column] == int(fixed_size)]
138
135
  if filtered_method_df.empty:
139
136
  continue
140
- filtered_method_df = filtered_method_df.sort_values(
141
- by=[size_column])
142
- functions = (pd.unique(filtered_method_df["function"])
143
- if functions is None else functions)
144
- precisions = (pd.unique(filtered_method_df["precision"])
145
- if precisions is None else precisions)
146
- backends = (pd.unique(filtered_method_df["backend"])
147
- if backends is None else backends)
137
+ filtered_method_df = filtered_method_df.sort_values(by=[size_column])
138
+ functions = (
139
+ pd.unique(filtered_method_df['function']) if functions is None else functions
140
+ )
141
+ precisions = (
142
+ pd.unique(filtered_method_df['precision']) if precisions is None else precisions
143
+ )
144
+ backends = pd.unique(filtered_method_df['backend']) if backends is None else backends
148
145
 
149
- combinations = product(backends, precisions, functions,
150
- plot_columns)
146
+ combinations = product(backends, precisions, functions, plot_columns)
151
147
 
152
148
  for backend, precision, function, plot_column in combinations:
153
-
154
149
  filtered_params_df = filtered_method_df[
155
- (filtered_method_df["backend"] == backend)
156
- & (filtered_method_df["precision"] == precision)
157
- & (filtered_method_df["function"] == function)]
150
+ (filtered_method_df['backend'] == backend)
151
+ & (filtered_method_df['precision'] == precision)
152
+ & (filtered_method_df['function'] == function)
153
+ ]
158
154
  if filtered_params_df.empty:
159
155
  continue
160
156
  x_vals, y_vals = plot_with_pdims_strategy(
@@ -172,12 +168,13 @@ def plot_scaling(
172
168
  y_values.extend(y_vals)
173
169
 
174
170
  if len(x_values) != 0:
175
- plotting_memory = "time" not in plot_columns[0].lower()
171
+ plotting_memory = 'time' not in plot_columns[0].lower()
172
+ figure_title = f'{title} {fixed_size}' if title is not None else None
176
173
  configure_axes(
177
174
  ax,
178
175
  x_values,
179
176
  y_values,
180
- f"{title} {fixed_size}",
177
+ figure_title,
181
178
  xlabel,
182
179
  plotting_memory,
183
180
  memory_units,
@@ -187,17 +184,12 @@ def plot_scaling(
187
184
  fig.delaxes(axs[i])
188
185
 
189
186
  fig.tight_layout()
190
- rect = FancyBboxPatch((0.1, 0.1),
191
- 0.8,
192
- 0.8,
193
- boxstyle="round,pad=0.02",
194
- ec="black",
195
- fc="none")
187
+ rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
196
188
  fig.patches.append(rect)
197
189
  if output is None:
198
190
  plt.show()
199
191
  else:
200
- plt.savefig(output, bbox_inches="tight", transparent=True)
192
+ plt.savefig(output)
201
193
 
202
194
 
203
195
  def plot_strong_scaling(
@@ -207,14 +199,14 @@ def plot_strong_scaling(
207
199
  functions: Optional[List[str]] = None,
208
200
  precisions: Optional[List[str]] = None,
209
201
  pdims: Optional[List[str]] = None,
210
- pdims_strategy: List[str] = ["plot_fastest"],
202
+ pdims_strategy: List[str] = ['plot_fastest'],
211
203
  print_decompositions: bool = False,
212
204
  backends: Optional[List[str]] = None,
213
- plot_columns: List[str] = ["mean_time"],
214
- memory_units: str = "bytes",
215
- label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
216
- xlabel: str = "Number of GPUs",
217
- title: str = "Data sizes",
205
+ plot_columns: List[str] = ['mean_time'],
206
+ memory_units: str = 'bytes',
207
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
208
+ xlabel: str = 'Number of GPUs',
209
+ title: str = 'Data sizes',
218
210
  figure_size: tuple = (6, 4),
219
211
  dark_bg: bool = False,
220
212
  output: Optional[str] = None,
@@ -235,14 +227,14 @@ def plot_strong_scaling(
235
227
  memory_units,
236
228
  )
237
229
  if len(dataframes) == 0:
238
- print(f"No dataframes found for the given arguments. Exiting...")
230
+ print('No dataframes found for the given arguments. Exiting...')
239
231
  return
240
232
 
241
233
  plot_scaling(
242
234
  dataframes,
243
235
  available_data_sizes,
244
- "gpus",
245
- "x",
236
+ 'gpus',
237
+ 'x',
246
238
  xlabel,
247
239
  title,
248
240
  figure_size,
@@ -266,14 +258,14 @@ def plot_weak_scaling(
266
258
  functions: Optional[List[str]] = None,
267
259
  precisions: Optional[List[str]] = None,
268
260
  pdims: Optional[List[str]] = None,
269
- pdims_strategy: List[str] = ["plot_fastest"],
261
+ pdims_strategy: List[str] = ['plot_fastest'],
270
262
  print_decompositions: bool = False,
271
263
  backends: Optional[List[str]] = None,
272
- plot_columns: List[str] = ["mean_time"],
273
- memory_units: str = "bytes",
274
- label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
275
- xlabel: str = "Data sizes",
276
- title: str = "Number of GPUs",
264
+ plot_columns: List[str] = ['mean_time'],
265
+ memory_units: str = 'bytes',
266
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
267
+ xlabel: str = 'Data sizes',
268
+ title: str = 'Number of GPUs',
277
269
  figure_size: tuple = (6, 4),
278
270
  dark_bg: bool = False,
279
271
  output: Optional[str] = None,
@@ -293,14 +285,14 @@ def plot_weak_scaling(
293
285
  memory_units,
294
286
  )
295
287
  if len(dataframes) == 0:
296
- print(f"No dataframes found for the given arguments. Exiting...")
288
+ print('No dataframes found for the given arguments. Exiting...')
297
289
  return
298
290
 
299
291
  plot_scaling(
300
292
  dataframes,
301
293
  available_gpu_counts,
302
- "x",
303
- "gpus",
294
+ 'x',
295
+ 'gpus',
304
296
  xlabel,
305
297
  title,
306
298
  figure_size,