jax-hpc-profiler 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -67,7 +67,6 @@ def create_argparser():
67
67
  plot_parser.add_argument('-fn',
68
68
  '--function_name',
69
69
  nargs='+',
70
- default=['FFT'],
71
70
  help='Function names to filter')
72
71
 
73
72
  # Time or memory related arguments
@@ -135,24 +134,28 @@ def create_argparser():
135
134
  'Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%)',
136
135
  default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
137
136
 
138
- args = parser.parse_args()
137
+ subparsers.add_parser('label_help',help='Label customization help')
139
138
 
140
- if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
141
- print(
142
- "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
143
- )
144
- args.pdim_strategy = ['plot_all']
145
-
146
- if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
147
- print(
148
- "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
149
- )
150
- args.pdim_strategy = ['plot_fastest']
151
- if args.plot_times is not None:
152
- args.plot_columns = args.plot_times
153
- elif args.plot_memory is not None:
154
- args.plot_columns = args.plot_memory
155
- else:
156
- raise ValueError('Either plot_times or plot_memory should be provided')
139
+ args = parser.parse_args()
140
+
141
+ # if command was plot, then check if pdim_strategy is validat
142
+ if args.command == 'plot':
143
+ if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
144
+ print(
145
+ "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
146
+ )
147
+ args.pdim_strategy = ['plot_all']
148
+
149
+ if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
150
+ print(
151
+ "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
152
+ )
153
+ args.pdim_strategy = ['plot_fastest']
154
+ if args.plot_times is not None:
155
+ args.plot_columns = args.plot_times
156
+ elif args.plot_memory is not None:
157
+ args.plot_columns = args.plot_memory
158
+ else:
159
+ raise ValueError('Either plot_times or plot_memory should be provided')
157
160
 
158
161
  return args
jax_hpc_profiler/main.py CHANGED
@@ -11,6 +11,15 @@ def main():
11
11
  if args.command == 'concat':
12
12
  input_dir, output_dir = args.input, args.output
13
13
  concatenate_csvs(input_dir, output_dir)
14
+ elif args.command == 'label_help':
15
+ print(f"Customize the label text for the plot. using these commands.")
16
+ print(' -- %m% or %methodname%: method name')
17
+ print(' -- %f% or %function%: function name')
18
+ print(' -- %pn% or %plot_name%: plot name')
19
+ print(' -- %pr% or %precision%: precision')
20
+ print(' -- %b% or %backend%: backend')
21
+ print(' -- %p% or %pdims%: pdims')
22
+ print(' -- %n% or %node%: node')
14
23
  elif args.command == 'plot':
15
24
  dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
16
25
  args.csv_files, args.precision, args.function_name, args.gpus,
@@ -23,11 +32,10 @@ def main():
23
32
  f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
24
33
  )
25
34
  # filter back the requested data sizes and gpus
26
- args.gpus = [gpu for gpu in args.gpus if gpu in available_gpu_counts]
27
- args.data_size = [
28
- data_size for data_size in args.data_size
29
- if data_size in available_data_sizes
30
- ]
35
+
36
+ args.gpus = available_gpu_counts if args.gpus is None else [gpu for gpu in args.gpus if gpu in available_gpu_counts]
37
+ args.data_size = available_data_sizes if args.data_size is None else [data_size for data_size in args.data_size if data_size in available_data_sizes]
38
+
31
39
  if len(args.gpus) == 0:
32
40
  print(f"No dataframes found for the given GPUs. Exiting...")
33
41
  sys.exit(1)
@@ -10,7 +10,7 @@ from matplotlib.patches import FancyBboxPatch
10
10
  from .utils import inspect_df, plot_with_pdims_strategy
11
11
 
12
12
  np.seterr(divide='ignore')
13
- plt.rcParams.update({'font.size': 10})
13
+ plt.rcParams.update({'font.size': 15})
14
14
 
15
15
 
16
16
  def configure_axes(ax: Axes,
@@ -38,7 +38,7 @@ def configure_axes(ax: Axes,
38
38
  f2 = lambda x: np.log2(x)
39
39
  g2 = lambda x: 2**x
40
40
  ax.set_xlim([min(x_values), max(x_values)])
41
- y_min, y_max = min(y_values) * 0.9, max(y_values) * 1.1
41
+ y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
42
42
  ax.set_title(title)
43
43
  ax.set_ylim([y_min, y_max])
44
44
  ax.set_xscale('function', functions=(f2, g2))
@@ -57,6 +57,7 @@ def configure_axes(ax: Axes,
57
57
  ax.legend(loc='lower center',
58
58
  bbox_to_anchor=(0.5, 0.05),
59
59
  ncol=4,
60
+ fontsize="x-large",
60
61
  prop={'size': 14})
61
62
 
62
63
 
jax_hpc_profiler/timer.py CHANGED
@@ -23,16 +23,11 @@ class Timer:
23
23
  self.compiled_code = {}
24
24
  self.save_jaxpr = save_jaxpr
25
25
 
26
- def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
27
- if cost_analysis is None:
28
- return None
29
- return cost_analysis[0]['flops']
30
-
31
26
  def _normalize_memory_units(self, memory_analysis) -> str:
32
27
 
33
28
  sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
34
29
  factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
- factor = int(np.log10(memory_analysis) // 3)
30
+ factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
36
31
 
37
32
  return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
33
 
@@ -62,11 +57,9 @@ class Timer:
62
57
  compiled = lowered.compile()
63
58
  memory_analysis = self._read_memory_analysis(
64
59
  compiled.memory_analysis())
65
- cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
66
60
 
67
61
  self.compiled_code["LOWERED"] = lowered.as_text()
68
62
  self.compiled_code["COMPILED"] = compiled.as_text()
69
- self.profiling_data["FLOPS"] = cost_analysis
70
63
  self.profiling_data["generated_code"] = memory_analysis[0]
71
64
  self.profiling_data["argument_size"] = memory_analysis[1]
72
65
  self.profiling_data["output_size"] = memory_analysis[2]
@@ -138,13 +131,13 @@ class Timer:
138
131
 
139
132
  times_array = self._get_mean_times()
140
133
  if jax.process_index() == 0:
134
+
141
135
  min_time = np.min(times_array)
142
136
  max_time = np.max(times_array)
143
137
  mean_time = np.mean(times_array)
144
138
  std_time = np.std(times_array)
145
139
  last_time = times_array[-1]
146
140
 
147
- flops = self.profiling_data["FLOPS"]
148
141
  generated_code = self.profiling_data["generated_code"]
149
142
  argument_size = self.profiling_data["argument_size"]
150
143
  output_size = self.profiling_data["output_size"]
@@ -153,7 +146,7 @@ class Timer:
153
146
  csv_line = (
154
147
  f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
155
148
  f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
156
- f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
149
+ f"{generated_code},{argument_size},{output_size},{temp_size}\n"
157
150
  )
158
151
 
159
152
  with open(csv_filename, 'a') as f:
jax_hpc_profiler/utils.py CHANGED
@@ -123,7 +123,7 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
123
123
  # Sort all and keep fastest
124
124
  sorted_dfs = []
125
125
  for _, group in df_decomp:
126
- group.sort_values(by=[y_col], inplace=True, ascending=False)
126
+ group.sort_values(by=[y_col], inplace=True, ascending=True)
127
127
  sorted_dfs.append(group.iloc[0])
128
128
  sorted_df = pd.DataFrame(sorted_dfs)
129
129
  label_params.update({
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,12 @@
1
+ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
+ jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
3
+ jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
4
+ jax_hpc_profiler/plotting.py,sha256=cwHznCZ2pF2J7AtyUOB3pASnahKBLRWHAPGXmGDvWas,8360
5
+ jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
6
+ jax_hpc_profiler/utils.py,sha256=okWQUJHblUKkYnw7j7wJ75PSbhVItXKkTMKjj0BmgR0,14132
7
+ jax_hpc_profiler-0.2.7.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.7.dist-info/METADATA,sha256=bQkpy5Kr8ybEM7GU7qR0FEnDV7xsLbrq98GRDfgDTQU,49250
9
+ jax_hpc_profiler-0.2.7.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
10
+ jax_hpc_profiler-0.2.7.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.7.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.1.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,12 +0,0 @@
1
- jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
- jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
3
- jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
- jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=r4Mw2tC82cxvkMPkIy8BuZjKikgxn6cviEgmu6rpC9o,8616
6
- jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
- jax_hpc_profiler-0.2.5.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.5.dist-info/METADATA,sha256=6Vk6fA1nz-m8ZZVzSZPsg9xR9iJkFJ01x36pddG0RAM,49250
9
- jax_hpc_profiler-0.2.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.5.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.5.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.5.dist-info/RECORD,,