jax-hpc-profiler 0.2.12__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,15 @@
1
1
  from .create_argparse import create_argparser
2
- from .plotting import plot_strong_scaling, plot_weak_scaling
2
+ from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
3
3
  from .timer import Timer
4
4
  from .utils import clean_up_csv, concatenate_csvs, plot_with_pdims_strategy
5
5
 
6
6
  __all__ = [
7
- 'create_argparser', 'plot_strong_scaling', 'plot_weak_scaling', 'Timer',
8
- 'clean_up_csv', 'concatenate_csvs', 'plot_with_pdims_strategy'
7
+ 'create_argparser',
8
+ 'plot_strong_scaling',
9
+ 'plot_weak_scaling',
10
+ 'plot_weak_fixed_scaling',
11
+ 'Timer',
12
+ 'clean_up_csv',
13
+ 'concatenate_csvs',
14
+ 'plot_with_pdims_strategy',
9
15
  ]
@@ -10,201 +10,209 @@ def create_argparser():
10
10
  argparse.Namespace
11
11
  Parsed and validated arguments.
12
12
  """
13
- parser = argparse.ArgumentParser(
14
- description="HPC Plotter for benchmarking data")
13
+ parser = argparse.ArgumentParser(description='HPC Plotter for benchmarking data')
15
14
 
16
15
  # Group for concatenation to ensure mutually exclusive behavior
17
- subparsers = parser.add_subparsers(dest="command", required=True)
16
+ subparsers = parser.add_subparsers(dest='command', required=True)
18
17
 
19
- concat_parser = subparsers.add_parser("concat",
20
- help="Concatenate CSV files")
21
- concat_parser.add_argument("input",
22
- type=str,
23
- help="Input directory for concatenation")
24
- concat_parser.add_argument("output",
25
- type=str,
26
- help="Output directory for concatenation")
18
+ concat_parser = subparsers.add_parser('concat', help='Concatenate CSV files')
19
+ concat_parser.add_argument('input', type=str, help='Input directory for concatenation')
20
+ concat_parser.add_argument('output', type=str, help='Output directory for concatenation')
27
21
 
28
22
  # Arguments for plotting
29
- plot_parser = subparsers.add_parser("plot", help="Plot CSV data")
30
- plot_parser.add_argument("-f",
31
- "--csv_files",
32
- nargs="+",
33
- help="List of CSV files to plot",
34
- required=True)
35
- plot_parser.add_argument(
36
- "-g",
37
- "--gpus",
38
- nargs="*",
23
+ plot_parser = subparsers.add_parser('plot', help='Plot CSV data')
24
+ plot_parser.add_argument(
25
+ '-f', '--csv_files', nargs='+', help='List of CSV files to plot', required=True
26
+ )
27
+ plot_parser.add_argument(
28
+ '-g',
29
+ '--gpus',
30
+ nargs='*',
39
31
  type=int,
40
- help="List of number of GPUs to plot",
32
+ help='List of number of GPUs to plot',
41
33
  default=None,
42
34
  )
43
35
  plot_parser.add_argument(
44
- "-d",
45
- "--data_size",
46
- nargs="*",
36
+ '-d',
37
+ '--data_size',
38
+ nargs='*',
47
39
  type=int,
48
- help="List of data sizes to plot",
40
+ help='List of data sizes to plot',
49
41
  default=None,
50
42
  )
51
43
 
52
44
  # pdims related arguments
53
45
  plot_parser.add_argument(
54
- "-fd",
55
- "--filter_pdims",
56
- nargs="*",
57
- help="List of pdims to filter, e.g., 1x4 2x2 4x8",
46
+ '-fd',
47
+ '--filter_pdims',
48
+ nargs='*',
49
+ help='List of pdims to filter, e.g., 1x4 2x2 4x8',
58
50
  default=None,
59
51
  )
60
52
  plot_parser.add_argument(
61
- "-ps",
62
- "--pdim_strategy",
63
- choices=["plot_all", "plot_fastest", "slab_yz", "slab_xy", "pencils"],
64
- nargs="*",
65
- default=["plot_fastest"],
66
- help="Strategy for plotting pdims",
53
+ '-ps',
54
+ '--pdim_strategy',
55
+ choices=['plot_all', 'plot_fastest', 'slab_yz', 'slab_xy', 'pencils'],
56
+ nargs='*',
57
+ default=['plot_fastest'],
58
+ help='Strategy for plotting pdims',
67
59
  )
68
60
 
69
61
  # Function and precision related arguments
70
62
  plot_parser.add_argument(
71
- "-pr",
72
- "--precision",
73
- choices=["float32", "float64"],
74
- default=["float32", "float64"],
75
- nargs="*",
76
- help="Precision to filter by (float32 or float64)",
63
+ '-pr',
64
+ '--precision',
65
+ choices=['float32', 'float64'],
66
+ default=['float32', 'float64'],
67
+ nargs='*',
68
+ help='Precision to filter by (float32 or float64)',
77
69
  )
78
70
  plot_parser.add_argument(
79
- "-fn",
80
- "--function_name",
81
- nargs="+",
82
- help="Function names to filter",
71
+ '-fn',
72
+ '--function_name',
73
+ nargs='+',
74
+ help='Function names to filter',
83
75
  default=None,
84
76
  )
85
77
 
86
78
  # Time or memory related arguments
87
79
  plotting_group = plot_parser.add_mutually_exclusive_group(required=True)
88
80
  plotting_group.add_argument(
89
- "-pt",
90
- "--plot_times",
91
- nargs="*",
81
+ '-pt',
82
+ '--plot_times',
83
+ nargs='*',
92
84
  choices=[
93
- "jit_time",
94
- "min_time",
95
- "max_time",
96
- "mean_time",
97
- "std_time",
98
- "last_time",
85
+ 'jit_time',
86
+ 'min_time',
87
+ 'max_time',
88
+ 'mean_time',
89
+ 'std_time',
90
+ 'last_time',
99
91
  ],
100
- help="Time columns to plot",
92
+ help='Time columns to plot',
101
93
  )
102
94
  plotting_group.add_argument(
103
- "-pm",
104
- "--plot_memory",
105
- nargs="*",
106
- choices=[
107
- "generated_code", "argument_size", "output_size", "temp_size"
108
- ],
109
- help="Memory columns to plot",
95
+ '-pm',
96
+ '--plot_memory',
97
+ nargs='*',
98
+ choices=['generated_code', 'argument_size', 'output_size', 'temp_size'],
99
+ help='Memory columns to plot',
110
100
  )
111
101
  plot_parser.add_argument(
112
- "-mu",
113
- "--memory_units",
114
- default="GB",
115
- help="Memory units to plot (KB, MB, GB, TB)",
102
+ '-mu',
103
+ '--memory_units',
104
+ default='GB',
105
+ help='Memory units to plot (KB, MB, GB, TB)',
116
106
  )
117
107
 
118
108
  # Plot customization arguments
119
- plot_parser.add_argument("-fs",
120
- "--figure_size",
121
- nargs=2,
122
- type=int,
123
- help="Figure size",
124
- default=(10, 6))
125
- plot_parser.add_argument("-o",
126
- "--output",
127
- help="Output file (if none then only show plot)",
128
- default=None)
129
- plot_parser.add_argument("-db",
130
- "--dark_bg",
131
- action="store_true",
132
- help="Use dark background for plotting")
133
- plot_parser.add_argument(
134
- "-pd",
135
- "--print_decompositions",
136
- action="store_true",
137
- help="Print decompositions on plot",
109
+ plot_parser.add_argument(
110
+ '-fs', '--figure_size', nargs=2, type=int, help='Figure size', default=(10, 6)
111
+ )
112
+ plot_parser.add_argument(
113
+ '-o', '--output', help='Output file (if none then only show plot)', default=None
114
+ )
115
+ plot_parser.add_argument(
116
+ '-db', '--dark_bg', action='store_true', help='Use dark background for plotting'
117
+ )
118
+ plot_parser.add_argument(
119
+ '-pd',
120
+ '--print_decompositions',
121
+ action='store_true',
122
+ help='Print decompositions on plot',
138
123
  )
139
124
 
140
125
  # Backend related arguments
141
126
  plot_parser.add_argument(
142
- "-b",
143
- "--backends",
144
- nargs="*",
145
- default=["MPI", "NCCL", "MPI4JAX"],
146
- help="List of backends to include",
127
+ '-b',
128
+ '--backends',
129
+ nargs='*',
130
+ default=['MPI', 'NCCL', 'MPI4JAX'],
131
+ help='List of backends to include',
147
132
  )
148
133
 
149
134
  # Scaling type argument
150
135
  plot_parser.add_argument(
151
- "-sc",
152
- "--scaling",
153
- choices=["Weak", "Strong", "w", "s"],
136
+ '-sc',
137
+ '--scaling',
138
+ choices=['Weak', 'Strong', 'WeakFixed', 'w', 's', 'wf'],
154
139
  required=True,
155
- help="Scaling type (Weak or Strong)",
140
+ help='Scaling type (Strong, Weak, or WeakFixed)',
141
+ )
142
+
143
+ # Weak-scaling specific options
144
+ plot_parser.add_argument(
145
+ '--weak_ideal_line',
146
+ action='store_true',
147
+ help='Overlay an ideal flat line for weak scaling (Weak mode only)',
148
+ )
149
+ plot_parser.add_argument(
150
+ '--weak_reverse_axes',
151
+ action='store_true',
152
+ help=(
153
+ 'Weak mode only: put data size on the x-axis and annotate each point with GPUs instead '
154
+ 'of data size. Requires --gpus and --data_size with equal lengths.'
155
+ ),
156
156
  )
157
157
 
158
158
  # Label customization argument
159
159
  plot_parser.add_argument(
160
- "-l",
161
- "--label_text",
160
+ '-l',
161
+ '--label_text',
162
162
  type=str,
163
- help=
164
- ("Custom label for the plot. You can use placeholders: %%decomposition%% "
165
- "(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), "
166
- "%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)"
167
- ),
168
- default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
163
+ help=(
164
+ 'Custom label for the plot. You can use placeholders: %%decomposition%% '
165
+ '(or %%p%%), %%precision%% (or %%pr%%), %%plot_name%% (or %%pn%%), '
166
+ '%%backend%% (or %%b%%), %%node%% (or %%n%%), %%methodname%% (or %%m%%)'
167
+ ),
168
+ default='%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
169
169
  )
170
170
 
171
171
  plot_parser.add_argument(
172
- "-xl",
173
- "--xlabel",
172
+ '-xl',
173
+ '--xlabel',
174
174
  type=str,
175
- help="X-axis label for the plot",
175
+ help='X-axis label for the plot',
176
176
  )
177
177
  plot_parser.add_argument(
178
- "-tl",
179
- "--title",
178
+ '-tl',
179
+ '--title',
180
180
  type=str,
181
- help="Title for the plot",
181
+ help='Title for the plot',
182
182
  )
183
183
 
184
- subparsers.add_parser("label_help", help="Label customization help")
184
+ subparsers.add_parser('label_help', help='Label customization help')
185
185
 
186
186
  args = parser.parse_args()
187
187
 
188
188
  # if command was plot, then check if pdim_strategy is valid
189
- if args.command == "plot":
190
- if "plot_all" in args.pdim_strategy and len(args.pdim_strategy) > 1:
189
+ if args.command == 'plot':
190
+ if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
191
191
  print(
192
- "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
192
+ """
193
+ Warning: 'plot_all' strategy is combined with other strategies.
194
+ Using 'plot_all' only.
195
+ """
193
196
  )
194
- args.pdim_strategy = ["plot_all"]
197
+ args.pdim_strategy = ['plot_all']
195
198
 
196
- if "plot_fastest" in args.pdim_strategy and len(
197
- args.pdim_strategy) > 1:
199
+ if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
198
200
  print(
199
- "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
201
+ """
202
+ Warning: 'plot_fastest' strategy is combined with other strategies.
203
+ Using 'plot_fastest' only.
204
+ """
200
205
  )
201
- args.pdim_strategy = ["plot_fastest"]
206
+ args.pdim_strategy = ['plot_fastest']
202
207
  if args.plot_times is not None:
203
208
  args.plot_columns = args.plot_times
204
209
  elif args.plot_memory is not None:
205
210
  args.plot_columns = args.plot_memory
206
211
  else:
207
- raise ValueError(
208
- "Either plot_times or plot_memory should be provided")
212
+ raise ValueError('Either plot_times or plot_memory should be provided')
213
+
214
+ # Note: for Weak scaling, plot_weak_scaling enforces that both gpus and
215
+ # data_size are provided and have matching lengths. For Strong and
216
+ # WeakFixed, gpus/data_size remain optional as before.
209
217
 
210
218
  return args
jax_hpc_profiler/main.py CHANGED
@@ -1,29 +1,26 @@
1
- import sys
2
- from typing import List, Optional
3
-
4
1
  from .create_argparse import create_argparser
5
- from .plotting import plot_strong_scaling, plot_weak_scaling
6
- from .utils import clean_up_csv, concatenate_csvs
2
+ from .plotting import plot_strong_scaling, plot_weak_fixed_scaling, plot_weak_scaling
3
+ from .utils import concatenate_csvs
7
4
 
8
5
 
9
6
  def main():
10
7
  args = create_argparser()
11
8
 
12
- if args.command == "concat":
9
+ if args.command == 'concat':
13
10
  input_dir, output_dir = args.input, args.output
14
11
  concatenate_csvs(input_dir, output_dir)
15
- elif args.command == "label_help":
16
- print(f"Customize the label text for the plot. using these commands.")
17
- print(" -- %m% or %methodname%: method name")
18
- print(" -- %f% or %function%: function name")
19
- print(" -- %pn% or %plot_name%: plot name")
20
- print(" -- %pr% or %precision%: precision")
21
- print(" -- %b% or %backend%: backend")
22
- print(" -- %p% or %pdims%: pdims")
23
- print(" -- %n% or %node%: node")
24
- elif args.command == "plot":
25
-
26
- if args.scaling.lower() == "weak" or args.scaling.lower() == "w":
12
+ elif args.command == 'label_help':
13
+ print('Customize the label text for the plot. using these commands.')
14
+ print(' -- %m% or %methodname%: method name')
15
+ print(' -- %f% or %function%: function name')
16
+ print(' -- %pn% or %plot_name%: plot name')
17
+ print(' -- %pr% or %precision%: precision')
18
+ print(' -- %b% or %backend%: backend')
19
+ print(' -- %p% or %pdims%: pdims')
20
+ print(' -- %n% or %node%: node')
21
+ elif args.command == 'plot':
22
+ scaling = args.scaling.lower()
23
+ if scaling in ('weak', 'w'):
27
24
  plot_weak_scaling(
28
25
  args.csv_files,
29
26
  args.gpus,
@@ -37,13 +34,15 @@ def main():
37
34
  args.plot_columns,
38
35
  args.memory_units,
39
36
  args.label_text,
37
+ args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
40
38
  args.title,
41
- args.label_text,
42
39
  args.figure_size,
43
40
  args.dark_bg,
44
41
  args.output,
42
+ args.weak_ideal_line,
43
+ args.weak_reverse_axes,
45
44
  )
46
- elif args.scaling.lower() == "strong" or args.scaling.lower() == "s":
45
+ elif scaling in ('strong', 's'):
47
46
  plot_strong_scaling(
48
47
  args.csv_files,
49
48
  args.gpus,
@@ -57,13 +56,33 @@ def main():
57
56
  args.plot_columns,
58
57
  args.memory_units,
59
58
  args.label_text,
60
- args.title,
59
+ args.xlabel if getattr(args, 'xlabel', None) is not None else 'Number of GPUs',
60
+ args.title if getattr(args, 'title', None) is not None else 'Data sizes',
61
+ args.figure_size,
62
+ args.dark_bg,
63
+ args.output,
64
+ )
65
+ elif scaling in ('weakfixed', 'wf'):
66
+ plot_weak_fixed_scaling(
67
+ args.csv_files,
68
+ args.gpus,
69
+ args.data_size,
70
+ args.function_name,
71
+ args.precision,
72
+ args.filter_pdims,
73
+ args.pdim_strategy,
74
+ args.print_decompositions,
75
+ args.backends,
76
+ args.plot_columns,
77
+ args.memory_units,
61
78
  args.label_text,
79
+ args.xlabel if getattr(args, 'xlabel', None) is not None else 'Data sizes',
80
+ args.title if getattr(args, 'title', None) is not None else 'Number of GPUs',
62
81
  args.figure_size,
63
82
  args.dark_bg,
64
83
  args.output,
65
84
  )
66
85
 
67
86
 
68
- if __name__ == "__main__":
87
+ if __name__ == '__main__':
69
88
  main()