jax-hpc-profiler 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,151 +11,187 @@ def create_argparser():
11
11
  Parsed and validated arguments.
12
12
  """
13
13
  parser = argparse.ArgumentParser(
14
- description='HPC Plotter for benchmarking data')
14
+ description="HPC Plotter for benchmarking data")
15
15
 
16
16
  # Group for concatenation to ensure mutually exclusive behavior
17
- subparsers = parser.add_subparsers(dest='command', required=True)
17
+ subparsers = parser.add_subparsers(dest="command", required=True)
18
18
 
19
- concat_parser = subparsers.add_parser('concat',
20
- help='Concatenate CSV files')
21
- concat_parser.add_argument('input',
19
+ concat_parser = subparsers.add_parser("concat",
20
+ help="Concatenate CSV files")
21
+ concat_parser.add_argument("input",
22
22
  type=str,
23
- help='Input directory for concatenation')
24
- concat_parser.add_argument('output',
23
+ help="Input directory for concatenation")
24
+ concat_parser.add_argument("output",
25
25
  type=str,
26
- help='Output directory for concatenation')
26
+ help="Output directory for concatenation")
27
27
 
28
28
  # Arguments for plotting
29
- plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
30
- plot_parser.add_argument('-f',
31
- '--csv_files',
32
- nargs='+',
33
- help='List of CSV files to plot',
29
+ plot_parser = subparsers.add_parser("plot", help="Plot CSV data")
30
+ plot_parser.add_argument("-f",
31
+ "--csv_files",
32
+ nargs="+",
33
+ help="List of CSV files to plot",
34
34
  required=True)
35
- plot_parser.add_argument('-g',
36
- '--gpus',
37
- nargs='*',
38
- type=int,
39
- help='List of number of GPUs to plot')
40
- plot_parser.add_argument('-d',
41
- '--data_size',
42
- nargs='*',
43
- type=int,
44
- help='List of data sizes to plot')
35
+ plot_parser.add_argument(
36
+ "-g",
37
+ "--gpus",
38
+ nargs="*",
39
+ type=int,
40
+ help="List of number of GPUs to plot",
41
+ default=None,
42
+ )
43
+ plot_parser.add_argument(
44
+ "-d",
45
+ "--data_size",
46
+ nargs="*",
47
+ type=int,
48
+ help="List of data sizes to plot",
49
+ default=None,
50
+ )
45
51
 
46
52
  # pdims related arguments
47
- plot_parser.add_argument('-fd',
48
- '--filter_pdims',
49
- nargs='*',
50
- help='List of pdims to filter, e.g., 1x4 2x2 4x8')
51
53
  plot_parser.add_argument(
52
- '-ps',
53
- '--pdim_strategy',
54
- choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
55
- nargs='*',
56
- default=['plot_fastest'],
57
- help='Strategy for plotting pdims')
54
+ "-fd",
55
+ "--filter_pdims",
56
+ nargs="*",
57
+ help="List of pdims to filter, e.g., 1x4 2x2 4x8",
58
+ default=None,
59
+ )
60
+ plot_parser.add_argument(
61
+ "-ps",
62
+ "--pdim_strategy",
63
+ choices=["plot_all", "plot_fastest", "slab_yz", "slab_xy", "pencils"],
64
+ nargs="*",
65
+ default=["plot_fastest"],
66
+ help="Strategy for plotting pdims",
67
+ )
58
68
 
59
69
  # Function and precision related arguments
60
70
  plot_parser.add_argument(
61
- '-pr',
62
- '--precision',
63
- choices=['float32', 'float64'],
64
- default=['float32', 'float64'],
65
- nargs='*',
66
- help='Precision to filter by (float32 or float64)')
67
- plot_parser.add_argument('-fn',
68
- '--function_name',
69
- nargs='+',
70
- help='Function names to filter')
71
+ "-pr",
72
+ "--precision",
73
+ choices=["float32", "float64"],
74
+ default=["float32", "float64"],
75
+ nargs="*",
76
+ help="Precision to filter by (float32 or float64)",
77
+ )
78
+ plot_parser.add_argument(
79
+ "-fn",
80
+ "--function_name",
81
+ nargs="+",
82
+ help="Function names to filter",
83
+ default=None,
84
+ )
71
85
 
72
86
  # Time or memory related arguments
73
87
  plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
74
- plotting_group.add_argument('-pt',
75
- '--plot_times',
76
- nargs='*',
77
- choices=[
78
- 'jit_time', 'min_time', 'max_time',
79
- 'mean_time', 'std_time', 'last_time'
80
- ],
81
- help='Time columns to plot')
82
- plotting_group.add_argument('-pm',
83
- '--plot_memory',
84
- nargs='*',
85
- choices=[
86
- 'generated_code', 'argument_size',
87
- 'output_size', 'temp_size'
88
- ],
89
- help='Memory columns to plot')
90
- plot_parser.add_argument('-mu',
91
- '--memory_units',
92
- default='GB',
93
- help='Memory units to plot (KB, MB, GB, TB)')
88
+ plotting_group.add_argument(
89
+ "-pt",
90
+ "--plot_times",
91
+ nargs="*",
92
+ choices=[
93
+ "jit_time",
94
+ "min_time",
95
+ "max_time",
96
+ "mean_time",
97
+ "std_time",
98
+ "last_time",
99
+ ],
100
+ help="Time columns to plot",
101
+ )
102
+ plotting_group.add_argument(
103
+ "-pm",
104
+ "--plot_memory",
105
+ nargs="*",
106
+ choices=[
107
+ "generated_code", "argument_size", "output_size", "temp_size"
108
+ ],
109
+ help="Memory columns to plot",
110
+ )
111
+ plot_parser.add_argument(
112
+ "-mu",
113
+ "--memory_units",
114
+ default="GB",
115
+ help="Memory units to plot (KB, MB, GB, TB)",
116
+ )
94
117
 
95
118
  # Plot customization arguments
96
- plot_parser.add_argument('-fs',
97
- '--figure_size',
119
+ plot_parser.add_argument("-fs",
120
+ "--figure_size",
98
121
  nargs=2,
99
122
  type=int,
100
- help='Figure size')
101
- plot_parser.add_argument('-o',
102
- '--output',
103
- help='Output file (if none then only show plot)',
123
+ help="Figure size",
124
+ default=(10, 6))
125
+ plot_parser.add_argument("-o",
126
+ "--output",
127
+ help="Output file (if none then only show plot)",
104
128
  default=None)
105
- plot_parser.add_argument('-db',
106
- '--dark_bg',
107
- action='store_true',
108
- help='Use dark background for plotting')
109
- plot_parser.add_argument('-pd',
110
- '--print_decompositions',
111
- action='store_true',
112
- help='Print decompositions on plot')
129
+ plot_parser.add_argument("-db",
130
+ "--dark_bg",
131
+ action="store_true",
132
+ help="Use dark background for plotting")
133
+ plot_parser.add_argument(
134
+ "-pd",
135
+ "--print_decompositions",
136
+ action="store_true",
137
+ help="Print decompositions on plot",
138
+ )
113
139
 
114
140
  # Backend related arguments
115
- plot_parser.add_argument('-b',
116
- '--backends',
117
- nargs='*',
118
- default=['MPI', 'NCCL', 'MPI4JAX'],
119
- help='List of backends to include')
141
+ plot_parser.add_argument(
142
+ "-b",
143
+ "--backends",
144
+ nargs="*",
145
+ default=["MPI", "NCCL", "MPI4JAX"],
146
+ help="List of backends to include",
147
+ )
120
148
 
121
149
  # Scaling type argument
122
- plot_parser.add_argument('-sc',
123
- '--scaling',
124
- choices=['Weak', 'Strong'],
125
- required=True,
126
- help='Scaling type (Weak or Strong)')
150
+ plot_parser.add_argument(
151
+ "-sc",
152
+ "--scaling",
153
+ choices=["Weak", "Strong", "w", "s"],
154
+ required=True,
155
+ help="Scaling type (Weak or Strong)",
156
+ )
127
157
 
128
158
  # Label customization argument
129
159
  plot_parser.add_argument(
130
- '-l',
131
- '--label_text',
160
+ "-l",
161
+ "--label_text",
132
162
  type=str,
133
163
  help=
134
- 'Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%)',
135
- default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
164
+ ("Custom label for the plot. You can use placeholders: %%decomposition%% "
165
+ "(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), "
166
+ "%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)"
167
+ ),
168
+ default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
169
+ )
136
170
 
137
- subparsers.add_parser('label_help',help='Label customization help')
171
+ subparsers.add_parser("label_help", help="Label customization help")
138
172
 
139
173
  args = parser.parse_args()
140
-
174
+
141
175
  # if command was plot, then check if pdim_strategy is validat
142
- if args.command == 'plot':
143
- if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
144
- print(
145
- "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
146
- )
147
- args.pdim_strategy = ['plot_all']
148
-
149
- if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
150
- print(
151
- "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
152
- )
153
- args.pdim_strategy = ['plot_fastest']
154
- if args.plot_times is not None:
155
- args.plot_columns = args.plot_times
156
- elif args.plot_memory is not None:
157
- args.plot_columns = args.plot_memory
158
- else:
159
- raise ValueError('Either plot_times or plot_memory should be provided')
176
+ if args.command == "plot":
177
+ if "plot_all" in args.pdim_strategy and len(args.pdim_strategy) > 1:
178
+ print(
179
+ "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
180
+ )
181
+ args.pdim_strategy = ["plot_all"]
182
+
183
+ if "plot_fastest" in args.pdim_strategy and len(
184
+ args.pdim_strategy) > 1:
185
+ print(
186
+ "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
187
+ )
188
+ args.pdim_strategy = ["plot_fastest"]
189
+ if args.plot_times is not None:
190
+ args.plot_columns = args.plot_times
191
+ elif args.plot_memory is not None:
192
+ args.plot_columns = args.plot_memory
193
+ else:
194
+ raise ValueError(
195
+ "Either plot_times or plot_memory should be provided")
160
196
 
161
197
  return args
jax_hpc_profiler/main.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import sys
2
+ from typing import List, Optional
2
3
 
3
4
  from .create_argparse import create_argparser
4
5
  from .plotting import plot_strong_scaling, plot_weak_scaling
@@ -8,55 +9,56 @@ from .utils import clean_up_csv, concatenate_csvs
8
9
  def main():
9
10
  args = create_argparser()
10
11
 
11
- if args.command == 'concat':
12
+ if args.command == "concat":
12
13
  input_dir, output_dir = args.input, args.output
13
14
  concatenate_csvs(input_dir, output_dir)
14
- elif args.command == 'label_help':
15
+ elif args.command == "label_help":
15
16
  print(f"Customize the label text for the plot. using these commands.")
16
- print(' -- %m% or %methodname%: method name')
17
- print(' -- %f% or %function%: function name')
18
- print(' -- %pn% or %plot_name%: plot name')
19
- print(' -- %pr% or %precision%: precision')
20
- print(' -- %b% or %backend%: backend')
21
- print(' -- %p% or %pdims%: pdims')
22
- print(' -- %n% or %node%: node')
23
- elif args.command == 'plot':
24
- dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
25
- args.csv_files, args.precision, args.function_name, args.gpus,
26
- args.data_size, args.filter_pdims, args.pdim_strategy,
27
- args.backends, args.memory_units)
28
- if len(dataframes) == 0:
29
- print(f"No dataframes found for the given arguments. Exiting...")
30
- sys.exit(1)
31
- print(
32
- f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
33
- )
34
- # filter back the requested data sizes and gpus
35
-
36
- args.gpus = available_gpu_counts if args.gpus is None else [gpu for gpu in args.gpus if gpu in available_gpu_counts]
37
- args.data_size = available_data_sizes if args.data_size is None else [data_size for data_size in args.data_size if data_size in available_data_sizes]
17
+ print(" -- %m% or %methodname%: method name")
18
+ print(" -- %f% or %function%: function name")
19
+ print(" -- %pn% or %plot_name%: plot name")
20
+ print(" -- %pr% or %precision%: precision")
21
+ print(" -- %b% or %backend%: backend")
22
+ print(" -- %p% or %pdims%: pdims")
23
+ print(" -- %n% or %node%: node")
24
+ elif args.command == "plot":
38
25
 
39
- if len(args.gpus) == 0:
40
- print(f"No dataframes found for the given GPUs. Exiting...")
41
- sys.exit(1)
42
- if len(args.data_size) == 0:
43
- print(f"No dataframes found for the given data sizes. Exiting...")
44
- sys.exit(1)
45
-
46
- if args.scaling == 'Weak':
47
- plot_weak_scaling(dataframes, args.gpus, args.figure_size,
48
- args.output, args.dark_bg,
49
- args.print_decompositions, args.backends,
50
- args.precision, args.function_name,
51
- args.plot_columns, args.memory_units,
52
- args.label_text, args.pdim_strategy)
53
- elif args.scaling == 'Strong':
54
- plot_strong_scaling(dataframes, args.data_size, args.figure_size,
55
- args.output, args.dark_bg,
56
- args.print_decompositions, args.backends,
57
- args.precision, args.function_name,
58
- args.plot_columns, args.memory_units,
59
- args.label_text, args.pdim_strategy)
26
+ if args.scaling.lower() == "weak" or args.scaling.lower() == "w":
27
+ plot_weak_scaling(
28
+ args.csv_files,
29
+ args.gpus,
30
+ args.data_size,
31
+ args.function_name,
32
+ args.precision,
33
+ args.filter_pdims,
34
+ args.pdim_strategy,
35
+ args.print_decompositions,
36
+ args.backends,
37
+ args.plot_columns,
38
+ args.memory_units,
39
+ args.label_text,
40
+ args.figure_size,
41
+ args.dark_bg,
42
+ args.output,
43
+ )
44
+ elif args.scaling.lower() == "strong" or args.scaling.lower() == "s":
45
+ plot_strong_scaling(
46
+ args.csv_files,
47
+ args.gpus,
48
+ args.data_size,
49
+ args.function_name,
50
+ args.precision,
51
+ args.filter_pdims,
52
+ args.pdim_strategy,
53
+ args.print_decompositions,
54
+ args.backends,
55
+ args.plot_columns,
56
+ args.memory_units,
57
+ args.label_text,
58
+ args.figure_size,
59
+ args.dark_bg,
60
+ args.output,
61
+ )
60
62
 
61
63
 
62
64
  if __name__ == "__main__":
@@ -4,22 +4,24 @@ from typing import Dict, List, Optional
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
6
  import pandas as pd
7
+ import seaborn as sns
7
8
  from matplotlib.axes import Axes
8
9
  from matplotlib.patches import FancyBboxPatch
9
10
 
10
- from .utils import inspect_df, plot_with_pdims_strategy, inspect_data
11
+ from .utils import clean_up_csv, inspect_df, plot_with_pdims_strategy
11
12
 
12
- np.seterr(divide='ignore')
13
- plt.rcParams.update({'font.size': 15})
13
+ np.seterr(divide="ignore")
14
14
 
15
15
 
16
- def configure_axes(ax: Axes,
17
- x_values: List[int],
18
- y_values: List[float],
19
- xlabel: str,
20
- title: str,
21
- plotting_memory: bool = False,
22
- memory_units: str = 'bytes'):
16
+ def configure_axes(
17
+ ax: Axes,
18
+ x_values: List[int],
19
+ y_values: List[float],
20
+ xlabel: str,
21
+ title: str,
22
+ plotting_memory: bool = False,
23
+ memory_units: str = "bytes",
24
+ ):
23
25
  """
24
26
  Configure the axes for the plot.
25
27
 
@@ -34,16 +36,17 @@ def configure_axes(ax: Axes,
34
36
  xlabel : str
35
37
  The label for the x-axis.
36
38
  """
37
- ylabel = 'Time (milliseconds)' if not plotting_memory else f'Memory ({memory_units})'
39
+ ylabel = ("Time (milliseconds)"
40
+ if not plotting_memory else f"Memory ({memory_units})")
38
41
  f2 = lambda x: np.log2(x)
39
42
  g2 = lambda x: 2**x
40
43
  ax.set_xlim([min(x_values), max(x_values)])
41
44
  y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
42
45
  ax.set_title(title)
43
46
  ax.set_ylim([y_min, y_max])
44
- ax.set_xscale('function', functions=(f2, g2))
47
+ ax.set_xscale("function", functions=(f2, g2))
45
48
  if not plotting_memory:
46
- ax.set_yscale('symlog')
49
+ ax.set_yscale("symlog")
47
50
  time_ticks = [
48
51
  10**t for t in range(int(np.floor(np.log10(y_min))), 1 +
49
52
  int(np.ceil(np.log10(y_max))))
@@ -53,31 +56,35 @@ def configure_axes(ax: Axes,
53
56
  ax.set_xlabel(xlabel)
54
57
  ax.set_ylabel(ylabel)
55
58
  for x_value in x_values:
56
- ax.axvline(x=x_value, color='gray', linestyle='--', alpha=0.5)
57
- ax.legend(loc='lower center',
58
- bbox_to_anchor=(0.5, 0.05),
59
- ncol=4,
60
- fontsize="x-large",
61
- prop={'size': 14})
62
-
63
-
64
- def plot_scaling(dataframes: Dict[str, pd.DataFrame],
65
- fixed_sizes: List[int],
66
- size_column: str,
67
- fixed_column: str,
68
- xlabel: str,
69
- title: str,
70
- figure_size: tuple = (6, 4),
71
- output: Optional[str] = None,
72
- dark_bg: bool = False,
73
- print_decompositions: bool = False,
74
- backends: List[str] = ['NCCL'],
75
- precisions: List[str] = ['float32'],
76
- functions: List[str] | None = None,
77
- plot_columns: List[str] = ['mean_time'],
78
- memory_units: str = 'bytes',
79
- label_text: str = 'plot',
80
- pdims_strategy: str = 'plot_fastest'):
59
+ ax.axvline(x=x_value, color="gray", linestyle="--", alpha=0.5)
60
+ ax.legend(
61
+ loc="lower center",
62
+ bbox_to_anchor=(0.5, 0.05),
63
+ ncol=4,
64
+ fontsize="x-large",
65
+ prop={"size": 14},
66
+ )
67
+
68
+
69
+ def plot_scaling(
70
+ dataframes: Dict[str, pd.DataFrame],
71
+ fixed_sizes: List[int],
72
+ size_column: str,
73
+ fixed_column: str,
74
+ xlabel: str,
75
+ title: str,
76
+ figure_size: tuple = (6, 4),
77
+ output: Optional[str] = None,
78
+ dark_bg: bool = False,
79
+ print_decompositions: bool = False,
80
+ backends: List[str] = ["NCCL"],
81
+ precisions: List[str] = ["float32"],
82
+ functions: List[str] | None = None,
83
+ plot_columns: List[str] = ["mean_time"],
84
+ memory_units: str = "bytes",
85
+ label_text: str = "plot",
86
+ pdims_strategy: List[str] = ["plot_fastest"],
87
+ ):
81
88
  """
82
89
  General scaling plot function based on the number of GPUs or data size.
83
90
 
@@ -106,8 +113,9 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
106
113
  pdims_strategy : str, optional
107
114
  Strategy for plotting pdims ('plot_all' or 'plot_fastest'), by default 'plot_fastest'.
108
115
  """
116
+
109
117
  if dark_bg:
110
- plt.style.use('dark_background')
118
+ plt.style.use("dark_background")
111
119
 
112
120
  num_subplots = len(fixed_sizes)
113
121
  num_rows = int(np.ceil(np.sqrt(num_subplots)))
@@ -118,7 +126,7 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
118
126
  axs = axs.flatten()
119
127
  else:
120
128
  axs = [axs]
121
-
129
+
122
130
  for i, fixed_size in enumerate(fixed_sizes):
123
131
  ax: Axes = axs[i]
124
132
 
@@ -131,30 +139,44 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
131
139
  continue
132
140
  filtered_method_df = filtered_method_df.sort_values(
133
141
  by=[size_column])
134
- functions = pd.unique(filtered_method_df['function']
135
- ) if functions is None else functions
142
+ functions = (pd.unique(filtered_method_df["function"])
143
+ if functions is None else functions)
136
144
  combinations = product(backends, precisions, functions,
137
145
  plot_columns)
138
146
 
139
147
  for backend, precision, function, plot_column in combinations:
140
-
148
+
141
149
  filtered_params_df = filtered_method_df[
142
- (filtered_method_df['backend'] == backend)
143
- & (filtered_method_df['precision'] == precision) &
144
- (filtered_method_df['function'] == function)]
150
+ (filtered_method_df["backend"] == backend)
151
+ & (filtered_method_df["precision"] == precision)
152
+ & (filtered_method_df["function"] == function)]
145
153
  if filtered_params_df.empty:
146
154
  continue
147
155
  x_vals, y_vals = plot_with_pdims_strategy(
148
- ax, filtered_params_df, method, pdims_strategy,
149
- print_decompositions, size_column, plot_column, label_text)
156
+ ax,
157
+ filtered_params_df,
158
+ method,
159
+ pdims_strategy,
160
+ print_decompositions,
161
+ size_column,
162
+ plot_column,
163
+ label_text,
164
+ )
150
165
 
151
166
  x_values.extend(x_vals)
152
167
  y_values.extend(y_vals)
153
-
154
- if len(x_values) != 0:
155
- plotting_memory = 'time' not in plot_columns[0].lower()
156
- configure_axes(ax, x_values, y_values, f"{title} {fixed_size}", xlabel,
157
- plotting_memory, memory_units)
168
+
169
+ if len(x_values) != 0:
170
+ plotting_memory = "time" not in plot_columns[0].lower()
171
+ configure_axes(
172
+ ax,
173
+ x_values,
174
+ y_values,
175
+ f"{title} {fixed_size}",
176
+ xlabel,
177
+ plotting_memory,
178
+ memory_units,
179
+ )
158
180
 
159
181
  for i in range(num_subplots, num_rows * num_cols):
160
182
  fig.delaxes(axs[i])
@@ -170,48 +192,117 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
170
192
  if output is None:
171
193
  plt.show()
172
194
  else:
173
- plt.savefig(output, bbox_inches='tight', transparent=False)
174
-
175
-
176
- def plot_strong_scaling(dataframes: Dict[str, pd.DataFrame],
177
- fixed_data_size: List[int],
178
- figure_size: tuple = (6, 4),
179
- output: Optional[str] = None,
180
- dark_bg: bool = False,
181
- print_decompositions: bool = False,
182
- backends: List[str] = ['NCCL'],
183
- precisions: List[str] = ['float32'],
184
- functions: List[str] | None = None,
185
- plot_columns: List[str] = ['mean_time'],
186
- memory_units: str = 'bytes',
187
- label_text: str = 'plot',
188
- pdims_strategy: str = 'plot_fastest'):
195
+ plt.savefig(output, bbox_inches="tight", transparent=True)
196
+
197
+
198
+ def plot_strong_scaling(
199
+ csv_files: List[str],
200
+ fixed_gpu_size: Optional[List[int]] = None,
201
+ fixed_data_size: Optional[List[int]] = None,
202
+ functions: List[str] | None = None,
203
+ precisions: List[str] = ["float32"],
204
+ pdims: Optional[List[str]] = None,
205
+ pdims_strategy: List[str] = ["plot_fastest"],
206
+ print_decompositions: bool = False,
207
+ backends: List[str] = ["NCCL"],
208
+ plot_columns: List[str] = ["mean_time"],
209
+ memory_units: str = "bytes",
210
+ label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
211
+ figure_size: tuple = (6, 4),
212
+ dark_bg: bool = False,
213
+ output: Optional[str] = None,
214
+ ):
189
215
  """
190
216
  Plot strong scaling based on the number of GPUs.
191
217
  """
192
- plot_scaling(dataframes, fixed_data_size, 'gpus', 'x', 'Number of GPUs',
193
- 'Data size', figure_size, output, dark_bg,
194
- print_decompositions, backends, precisions, functions,
195
- plot_columns, memory_units, label_text, pdims_strategy)
196
-
197
-
198
- def plot_weak_scaling(dataframes: Dict[str, pd.DataFrame],
199
- fixed_gpu_size: List[int],
200
- figure_size: tuple = (6, 4),
201
- output: Optional[str] = None,
202
- dark_bg: bool = False,
203
- print_decompositions: bool = False,
204
- backends: List[str] = ['NCCL'],
205
- precisions: List[str] = ['float32'],
206
- functions: List[str] | None = None,
207
- plot_columns: List[str] = ['mean_time'],
208
- memory_units: str = 'bytes',
209
- label_text: str = 'plot',
210
- pdims_strategy: str = 'plot_fastest'):
218
+
219
+ dataframes, _, available_data_sizes = clean_up_csv(
220
+ csv_files,
221
+ precisions,
222
+ functions,
223
+ fixed_gpu_size,
224
+ fixed_data_size,
225
+ pdims,
226
+ pdims_strategy,
227
+ backends,
228
+ memory_units,
229
+ )
230
+ if len(dataframes) == 0:
231
+ print(f"No dataframes found for the given arguments. Exiting...")
232
+ return
233
+
234
+ plot_scaling(
235
+ dataframes,
236
+ available_data_sizes,
237
+ "gpus",
238
+ "x",
239
+ "Number of GPUs",
240
+ "Data size",
241
+ figure_size,
242
+ output,
243
+ dark_bg,
244
+ print_decompositions,
245
+ backends,
246
+ precisions,
247
+ functions,
248
+ plot_columns,
249
+ memory_units,
250
+ label_text,
251
+ pdims_strategy,
252
+ )
253
+
254
+
255
+ def plot_weak_scaling(
256
+ csv_files: List[str],
257
+ fixed_gpu_size: Optional[List[int]] = None,
258
+ fixed_data_size: Optional[List[int]] = None,
259
+ functions: List[str] | None = None,
260
+ precisions: List[str] = ["float32"],
261
+ pdims: Optional[List[str]] = None,
262
+ pdims_strategy: List[str] = ["plot_fastest"],
263
+ print_decompositions: bool = False,
264
+ backends: List[str] = ["NCCL"],
265
+ plot_columns: List[str] = ["mean_time"],
266
+ memory_units: str = "bytes",
267
+ label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
268
+ figure_size: tuple = (6, 4),
269
+ dark_bg: bool = False,
270
+ output: Optional[str] = None,
271
+ ):
211
272
  """
212
273
  Plot weak scaling based on the data size.
213
274
  """
214
- plot_scaling(dataframes, fixed_gpu_size, 'x', 'gpus', 'Data size',
215
- 'Number of GPUs', figure_size, output, dark_bg,
216
- print_decompositions, backends, precisions, functions,
217
- plot_columns, memory_units, label_text, pdims_strategy)
275
+ dataframes, available_gpu_counts, _ = clean_up_csv(
276
+ csv_files,
277
+ precisions,
278
+ functions,
279
+ fixed_gpu_size,
280
+ fixed_data_size,
281
+ pdims,
282
+ pdims_strategy,
283
+ backends,
284
+ memory_units,
285
+ )
286
+ if len(dataframes) == 0:
287
+ print(f"No dataframes found for the given arguments. Exiting...")
288
+ return
289
+
290
+ plot_scaling(
291
+ dataframes,
292
+ available_gpu_counts,
293
+ "x",
294
+ "gpus",
295
+ "Data size",
296
+ "Number of GPUs",
297
+ figure_size,
298
+ output,
299
+ dark_bg,
300
+ print_decompositions,
301
+ backends,
302
+ precisions,
303
+ functions,
304
+ plot_columns,
305
+ memory_units,
306
+ label_text,
307
+ pdims_strategy,
308
+ )
jax_hpc_profiler/timer.py CHANGED
@@ -16,36 +16,45 @@ from tabulate import tabulate
16
16
 
17
17
  class Timer:
18
18
 
19
- def __init__(self, save_jaxpr=False):
20
- self.jit_time = None
19
+ def __init__(self, save_jaxpr=False, jax_fn=True, devices=None):
20
+ self.jit_time = 0.0
21
21
  self.times = []
22
22
  self.profiling_data = {}
23
23
  self.compiled_code = {}
24
24
  self.save_jaxpr = save_jaxpr
25
+ self.jax_fn = jax_fn
26
+ self.devices = devices
25
27
 
26
28
  def _normalize_memory_units(self, memory_analysis) -> str:
27
29
 
28
- sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
30
+ if not self.jax_fn:
31
+ return memory_analysis
32
+
33
+ sizes_str = ["B", "KB", "MB", "GB", "TB", "PB"]
29
34
  factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
30
- factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
35
+ factor = 0 if memory_analysis == 0 else int(
36
+ np.log10(memory_analysis) // 3)
31
37
 
32
38
  return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
33
39
 
34
40
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
35
41
  if memory_analysis is None:
36
42
  return None, None, None, None
37
- return (memory_analysis.generated_code_size_in_bytes,
38
- memory_analysis.argument_size_in_bytes,
39
- memory_analysis.output_size_in_bytes,
40
- memory_analysis.temp_size_in_bytes)
43
+ return (
44
+ memory_analysis.generated_code_size_in_bytes,
45
+ memory_analysis.argument_size_in_bytes,
46
+ memory_analysis.output_size_in_bytes,
47
+ memory_analysis.temp_size_in_bytes,
48
+ )
41
49
 
42
50
  def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
43
51
  start = time.perf_counter()
44
52
  out = fun(*args)
45
- if ndarray_arg is None:
46
- out.block_until_ready()
47
- else:
48
- out[ndarray_arg].block_until_ready()
53
+ if self.jax_fn:
54
+ if ndarray_arg is None:
55
+ out.block_until_ready()
56
+ else:
57
+ out[ndarray_arg].block_until_ready()
49
58
  end = time.perf_counter()
50
59
  self.jit_time = (end - start) * 1e3
51
60
 
@@ -53,78 +62,90 @@ class Timer:
53
62
  jaxpr = make_jaxpr(fun)(*args)
54
63
  self.compiled_code["JAXPR"] = jaxpr.pretty_print()
55
64
 
56
- lowered = jax.jit(fun).lower(*args)
57
- compiled = lowered.compile()
58
- memory_analysis = self._read_memory_analysis(
59
- compiled.memory_analysis())
60
-
61
- self.compiled_code["LOWERED"] = lowered.as_text()
62
- self.compiled_code["COMPILED"] = compiled.as_text()
63
- self.profiling_data["generated_code"] = memory_analysis[0]
64
- self.profiling_data["argument_size"] = memory_analysis[1]
65
- self.profiling_data["output_size"] = memory_analysis[2]
66
- self.profiling_data["temp_size"] = memory_analysis[3]
65
+ if self.jax_fn:
66
+ lowered = jax.jit(fun).lower(*args)
67
+ compiled = lowered.compile()
68
+ memory_analysis = self._read_memory_analysis(
69
+ compiled.memory_analysis())
70
+
71
+ self.compiled_code["LOWERED"] = lowered.as_text()
72
+ self.compiled_code["COMPILED"] = compiled.as_text()
73
+ self.profiling_data["generated_code"] = memory_analysis[0]
74
+ self.profiling_data["argument_size"] = memory_analysis[1]
75
+ self.profiling_data["output_size"] = memory_analysis[2]
76
+ self.profiling_data["temp_size"] = memory_analysis[3]
77
+
67
78
  return out
68
79
 
69
80
  def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
70
81
  start = time.perf_counter()
71
82
  out = fun(*args)
72
- if ndarray_arg is None:
73
- out.block_until_ready()
74
- else:
75
- out[ndarray_arg].block_until_ready()
83
+ if self.jax_fn:
84
+ if ndarray_arg is None:
85
+ out.block_until_ready()
86
+ else:
87
+ out[ndarray_arg].block_until_ready()
76
88
  end = time.perf_counter()
77
89
  self.times.append((end - start) * 1e3)
78
90
  return out
79
91
 
80
92
  def _get_mean_times(self) -> np.ndarray:
81
- if jax.device_count() == 1:
93
+ if jax.device_count() == 1 or jax.process_count() == 1:
82
94
  return np.array(self.times)
83
95
 
84
- devices = mesh_utils.create_device_mesh((jax.device_count(), ))
85
- mesh = Mesh(devices, ('x', ))
86
- sharding = NamedSharding(mesh, P('x'))
96
+ if self.devices is None:
97
+ self.devices = jax.devices()
98
+
99
+ mesh = jax.make_mesh((len(self.devices), ), ("x", ),
100
+ devices=self.devices)
101
+ sharding = NamedSharding(mesh, P("x"))
87
102
 
88
103
  times_array = jnp.array(self.times)
89
104
  global_shape = (jax.device_count(), times_array.shape[0])
90
105
  global_times = jax.make_array_from_callback(
91
106
  shape=global_shape,
92
107
  sharding=sharding,
93
- data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
108
+ data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
109
+ )
94
110
 
95
111
  @partial(shard_map,
96
112
  mesh=mesh,
97
- in_specs=P('x'),
113
+ in_specs=P("x"),
98
114
  out_specs=P(),
99
115
  check_rep=False)
100
116
  def get_mean_times(times):
101
- return jax.lax.pmean(times, axis_name='x')
117
+ return jax.lax.pmean(times, axis_name="x")
102
118
 
103
119
  times_array = get_mean_times(global_times)
104
120
  times_array.block_until_ready()
105
121
  return np.array(times_array.addressable_data(0)[0])
106
122
 
107
- def report(self,
108
- csv_filename: str,
109
- function: str,
110
- x: int,
111
- y: int | None = None,
112
- z: int | None = None,
113
- precision: str = "float32",
114
- px: int = 1,
115
- py: int = 1,
116
- backend: str = "NCCL",
117
- nodes: int = 1,
118
- md_filename: str | None = None,
119
- extra_info: dict = {}):
123
+ def report(
124
+ self,
125
+ csv_filename: str,
126
+ function: str,
127
+ x: int,
128
+ y: int | None = None,
129
+ z: int | None = None,
130
+ precision: str = "float32",
131
+ px: int = 1,
132
+ py: int = 1,
133
+ backend: str = "NCCL",
134
+ nodes: int = 1,
135
+ md_filename: str | None = None,
136
+ extra_info: dict = {},
137
+ ):
120
138
 
121
139
  if md_filename is None:
122
- dirname, filename = os.path.dirname(
123
- csv_filename), os.path.splitext(
124
- os.path.basename(csv_filename))[0]
140
+ dirname, filename = (
141
+ os.path.dirname(csv_filename),
142
+ os.path.splitext(os.path.basename(csv_filename))[0],
143
+ )
125
144
  report_folder = filename if dirname == "" else f"{dirname}/{filename}"
126
145
  os.makedirs(report_folder, exist_ok=True)
127
- md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
146
+ md_filename = (
147
+ f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
148
+ )
128
149
 
129
150
  y = x if y is None else y
130
151
  z = x if z is None else z
@@ -138,10 +159,17 @@ class Timer:
138
159
  std_time = np.std(times_array)
139
160
  last_time = times_array[-1]
140
161
 
141
- generated_code = self.profiling_data["generated_code"]
142
- argument_size = self.profiling_data["argument_size"]
143
- output_size = self.profiling_data["output_size"]
144
- temp_size = self.profiling_data["temp_size"]
162
+ if self.jax_fn:
163
+
164
+ generated_code = self.profiling_data["generated_code"]
165
+ argument_size = self.profiling_data["argument_size"]
166
+ output_size = self.profiling_data["output_size"]
167
+ temp_size = self.profiling_data["temp_size"]
168
+ else:
169
+ generated_code = "N/A"
170
+ argument_size = "N/A"
171
+ output_size = "N/A"
172
+ temp_size = "N/A"
145
173
 
146
174
  csv_line = (
147
175
  f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
@@ -149,7 +177,7 @@ class Timer:
149
177
  f"{generated_code},{argument_size},{output_size},{temp_size}\n"
150
178
  )
151
179
 
152
- with open(csv_filename, 'a') as f:
180
+ with open(csv_filename, "a") as f:
153
181
  f.write(csv_line)
154
182
 
155
183
  param_dict = {
@@ -175,44 +203,56 @@ class Timer:
175
203
  "Argument Size": self._normalize_memory_units(argument_size),
176
204
  "Output Size": self._normalize_memory_units(output_size),
177
205
  "Temporary Size": self._normalize_memory_units(temp_size),
178
- "FLOPS": self.profiling_data["FLOPS"]
179
206
  }
180
207
  iteration_runs = {}
181
208
  for i in range(len(times_array)):
182
209
  iteration_runs[f"Run {i}"] = times_array[i]
183
210
 
184
- with open(md_filename, 'w') as f:
211
+ with open(md_filename, "w") as f:
185
212
  f.write(f"# Reporting for {function}\n")
186
213
  f.write(f"## Parameters\n")
187
214
  f.write(
188
- tabulate(param_dict.items(),
189
- headers=["Parameter", "Value"],
190
- tablefmt='github'))
215
+ tabulate(
216
+ param_dict.items(),
217
+ headers=["Parameter", "Value"],
218
+ tablefmt="github",
219
+ ))
191
220
  f.write("\n---\n")
192
221
  f.write(f"## Profiling Data\n")
193
222
  f.write(
194
- tabulate(profiling_result.items(),
195
- headers=["Parameter", "Value"],
196
- tablefmt='github'))
223
+ tabulate(
224
+ profiling_result.items(),
225
+ headers=["Parameter", "Value"],
226
+ tablefmt="github",
227
+ ))
197
228
  f.write("\n---\n")
198
229
  f.write(f"## Iteration Runs\n")
199
230
  f.write(
200
- tabulate(iteration_runs.items(),
201
- headers=["Iteration", "Time"],
202
- tablefmt='github'))
203
- f.write("\n---\n")
204
- f.write(f"## Compiled Code\n")
205
- f.write(f"```hlo\n")
206
- f.write(self.compiled_code["COMPILED"])
207
- f.write(f"\n```\n")
208
- f.write("\n---\n")
209
- f.write(f"## Lowered Code\n")
210
- f.write(f"```hlo\n")
211
- f.write(self.compiled_code["LOWERED"])
212
- f.write(f"\n```\n")
213
- f.write("\n---\n")
214
- if self.save_jaxpr:
215
- f.write(f"## JAXPR\n")
216
- f.write(f"```haskel\n")
217
- f.write(self.compiled_code["JAXPR"])
231
+ tabulate(
232
+ iteration_runs.items(),
233
+ headers=["Iteration", "Time"],
234
+ tablefmt="github",
235
+ ))
236
+ if self.jax_fn:
237
+ f.write("\n---\n")
238
+ f.write(f"## Compiled Code\n")
239
+ f.write(f"```hlo\n")
240
+ f.write(self.compiled_code["COMPILED"])
218
241
  f.write(f"\n```\n")
242
+ f.write("\n---\n")
243
+ f.write(f"## Lowered Code\n")
244
+ f.write(f"```hlo\n")
245
+ f.write(self.compiled_code["LOWERED"])
246
+ f.write(f"\n```\n")
247
+ f.write("\n---\n")
248
+ if self.save_jaxpr:
249
+ f.write(f"## JAXPR\n")
250
+ f.write(f"```haskel\n")
251
+ f.write(self.compiled_code["JAXPR"])
252
+ f.write(f"\n```\n")
253
+
254
+ # Reset the timer
255
+ self.jit_time = 0.0
256
+ self.times = []
257
+ self.profiling_data = {}
258
+ self.compiled_code = {}
jax_hpc_profiler/utils.py CHANGED
@@ -250,7 +250,7 @@ def clean_up_csv(
250
250
  pdims_strategy: List[str] = ['plot_fastest'],
251
251
  backends: List[str] = ['MPI', 'NCCL', 'MPI4JAX'],
252
252
  memory_units: str = 'KB',
253
- ) -> Tuple[Dict[str, pd.DataFrame], set, set]:
253
+ ) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
254
254
  """
255
255
  Clean up and aggregate data from CSV files.
256
256
 
@@ -336,7 +336,7 @@ def clean_up_csv(
336
336
  # Filter data sizes
337
337
  if data_sizes:
338
338
  df = df[df['x'].isin(data_sizes)]
339
-
339
+
340
340
  # Filter pdims
341
341
  if pdims:
342
342
  px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
@@ -394,4 +394,17 @@ def clean_up_csv(
394
394
  else:
395
395
  dataframes[file_name] = pd.concat([dataframes[file_name], df])
396
396
 
397
+ print(f"requested GPUS: {gpus} available GPUS: {available_gpu_counts}")
398
+ print(
399
+ f"requested data sizes: {data_sizes} available data sizes: {available_data_sizes}"
400
+ )
401
+
402
+ available_gpu_counts = (available_gpu_counts if gpus is None else [
403
+ gpu for gpu in gpus if gpu in available_gpu_counts
404
+ ])
405
+ available_data_sizes = (available_data_sizes if data_sizes is None else [
406
+ data_size for data_size in data_sizes
407
+ if data_size in available_data_sizes
408
+ ])
409
+
397
410
  return dataframes, available_gpu_counts, available_data_sizes
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
- License: GNU GENERAL PUBLIC LICENSE
6
+ License: GNU GENERAL PUBLIC LICENSE
7
7
  Version 3, 29 June 2007
8
8
 
9
9
  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
@@ -0,0 +1,12 @@
1
+ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
+ jax_hpc_profiler/create_argparse.py,sha256=CSdl76LvaTVVn43dkwpVyiIkyl4lHlDCiI5jvUrIoj0,6059
3
+ jax_hpc_profiler/main.py,sha256=2zPVTGRgFkYV75EJA1eoOqf92gCRXAtg-28cFgRy3Bw,2164
4
+ jax_hpc_profiler/plotting.py,sha256=8ELOB_Yv_AdSVWtS-jrRNm0HtK5FgKwf_ljeNRfdp14,9087
5
+ jax_hpc_profiler/timer.py,sha256=p7MUcbd2H4_tRAevhG9T4jJ8XL-liComvJn2sis4psM,9209
6
+ jax_hpc_profiler/utils.py,sha256=hSsS34i46WdCR9XRW1-02fI_k0RUty78imnI-xAc-tY,14644
7
+ jax_hpc_profiler-0.2.9.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.9.dist-info/METADATA,sha256=CelrNVm13lK7L1ZkdOqD8Tm7qLBIF1oCyaghzDdrLRg,49270
9
+ jax_hpc_profiler-0.2.9.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
10
+ jax_hpc_profiler-0.2.9.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.9.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.9.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.0.0)
2
+ Generator: setuptools (75.6.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,12 +0,0 @@
1
- jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
- jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
3
- jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
4
- jax_hpc_profiler/plotting.py,sha256=jnWpyfZOB5w5L4BNIU5euKbUt__wfGPyCjUITrEwScM,8431
5
- jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
6
- jax_hpc_profiler/utils.py,sha256=tUXnNHwQSSCqA6XxLd4MoV2gyFIC7ncB-uvVc2INhms,14140
7
- jax_hpc_profiler-0.2.8.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.8.dist-info/METADATA,sha256=NDOKx7RzKr4Z4mJnY457-4uxaP2TDhwWYH5rZRxfukQ,49250
9
- jax_hpc_profiler-0.2.8.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
10
- jax_hpc_profiler-0.2.8.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.8.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.8.dist-info/RECORD,,