jax-hpc-profiler 0.2.6__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17):
  1. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/PKG-INFO +1 -1
  2. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/pyproject.toml +1 -1
  3. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/create_argparse.py +2 -1
  4. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/main.py +13 -5
  5. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/plotting.py +13 -10
  6. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/timer.py +2 -10
  7. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/utils.py +3 -3
  8. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
  9. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/LICENSE +0 -0
  10. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/README.md +0 -0
  11. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/setup.cfg +0 -0
  12. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/__init__.py +0 -0
  13. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  14. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  15. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  16. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  17. {jax_hpc_profiler-0.2.6 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.6"
7
+ version = "0.2.8"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -67,7 +67,6 @@ def create_argparser():
67
67
  plot_parser.add_argument('-fn',
68
68
  '--function_name',
69
69
  nargs='+',
70
- default=['FFT'],
71
70
  help='Function names to filter')
72
71
 
73
72
  # Time or memory related arguments
@@ -135,6 +134,8 @@ def create_argparser():
135
134
  'Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%)',
136
135
  default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
137
136
 
137
+ subparsers.add_parser('label_help',help='Label customization help')
138
+
138
139
  args = parser.parse_args()
139
140
 
140
141
  # if command was plot, then check if pdim_strategy is validat
@@ -11,6 +11,15 @@ def main():
11
11
  if args.command == 'concat':
12
12
  input_dir, output_dir = args.input, args.output
13
13
  concatenate_csvs(input_dir, output_dir)
14
+ elif args.command == 'label_help':
15
+ print(f"Customize the label text for the plot. using these commands.")
16
+ print(' -- %m% or %methodname%: method name')
17
+ print(' -- %f% or %function%: function name')
18
+ print(' -- %pn% or %plot_name%: plot name')
19
+ print(' -- %pr% or %precision%: precision')
20
+ print(' -- %b% or %backend%: backend')
21
+ print(' -- %p% or %pdims%: pdims')
22
+ print(' -- %n% or %node%: node')
14
23
  elif args.command == 'plot':
15
24
  dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
16
25
  args.csv_files, args.precision, args.function_name, args.gpus,
@@ -23,11 +32,10 @@ def main():
23
32
  f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
24
33
  )
25
34
  # filter back the requested data sizes and gpus
26
- args.gpus = [gpu for gpu in args.gpus if gpu in available_gpu_counts]
27
- args.data_size = [
28
- data_size for data_size in args.data_size
29
- if data_size in available_data_sizes
30
- ]
35
+
36
+ args.gpus = available_gpu_counts if args.gpus is None else [gpu for gpu in args.gpus if gpu in available_gpu_counts]
37
+ args.data_size = available_data_sizes if args.data_size is None else [data_size for data_size in args.data_size if data_size in available_data_sizes]
38
+
31
39
  if len(args.gpus) == 0:
32
40
  print(f"No dataframes found for the given GPUs. Exiting...")
33
41
  sys.exit(1)
@@ -7,10 +7,10 @@ import pandas as pd
7
7
  from matplotlib.axes import Axes
8
8
  from matplotlib.patches import FancyBboxPatch
9
9
 
10
- from .utils import inspect_df, plot_with_pdims_strategy
10
+ from .utils import inspect_df, plot_with_pdims_strategy, inspect_data
11
11
 
12
12
  np.seterr(divide='ignore')
13
- plt.rcParams.update({'font.size': 10})
13
+ plt.rcParams.update({'font.size': 15})
14
14
 
15
15
 
16
16
  def configure_axes(ax: Axes,
@@ -38,7 +38,7 @@ def configure_axes(ax: Axes,
38
38
  f2 = lambda x: np.log2(x)
39
39
  g2 = lambda x: 2**x
40
40
  ax.set_xlim([min(x_values), max(x_values)])
41
- y_min, y_max = min(y_values) * 0.9, max(y_values) * 1.1
41
+ y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
42
42
  ax.set_title(title)
43
43
  ax.set_ylim([y_min, y_max])
44
44
  ax.set_xscale('function', functions=(f2, g2))
@@ -57,6 +57,7 @@ def configure_axes(ax: Axes,
57
57
  ax.legend(loc='lower center',
58
58
  bbox_to_anchor=(0.5, 0.05),
59
59
  ncol=4,
60
+ fontsize="x-large",
60
61
  prop={'size': 14})
61
62
 
62
63
 
@@ -117,13 +118,13 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
117
118
  axs = axs.flatten()
118
119
  else:
119
120
  axs = [axs]
120
-
121
+
121
122
  for i, fixed_size in enumerate(fixed_sizes):
122
123
  ax: Axes = axs[i]
123
124
 
125
+ x_values = []
126
+ y_values = []
124
127
  for method, df in dataframes.items():
125
- x_values = []
126
- y_values = []
127
128
 
128
129
  filtered_method_df = df[df[fixed_column] == int(fixed_size)]
129
130
  if filtered_method_df.empty:
@@ -136,6 +137,7 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
136
137
  plot_columns)
137
138
 
138
139
  for backend, precision, function, plot_column in combinations:
140
+
139
141
  filtered_params_df = filtered_method_df[
140
142
  (filtered_method_df['backend'] == backend)
141
143
  & (filtered_method_df['precision'] == precision) &
@@ -148,10 +150,11 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
148
150
 
149
151
  x_values.extend(x_vals)
150
152
  y_values.extend(y_vals)
151
-
152
- plotting_memory = 'time' not in plot_columns[0].lower()
153
- configure_axes(ax, x_values, y_values, f"{title} {fixed_size}", xlabel,
154
- plotting_memory, memory_units)
153
+
154
+ if len(x_values) != 0:
155
+ plotting_memory = 'time' not in plot_columns[0].lower()
156
+ configure_axes(ax, x_values, y_values, f"{title} {fixed_size}", xlabel,
157
+ plotting_memory, memory_units)
155
158
 
156
159
  for i in range(num_subplots, num_rows * num_cols):
157
160
  fig.delaxes(axs[i])
@@ -23,16 +23,11 @@ class Timer:
23
23
  self.compiled_code = {}
24
24
  self.save_jaxpr = save_jaxpr
25
25
 
26
- def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
27
- if cost_analysis is None:
28
- return None
29
- return cost_analysis[0]['flops']
30
-
31
26
  def _normalize_memory_units(self, memory_analysis) -> str:
32
27
 
33
28
  sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
34
29
  factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
- factor = int(np.log10(memory_analysis) // 3)
30
+ factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
36
31
 
37
32
  return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
33
 
@@ -62,11 +57,9 @@ class Timer:
62
57
  compiled = lowered.compile()
63
58
  memory_analysis = self._read_memory_analysis(
64
59
  compiled.memory_analysis())
65
- cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
66
60
 
67
61
  self.compiled_code["LOWERED"] = lowered.as_text()
68
62
  self.compiled_code["COMPILED"] = compiled.as_text()
69
- self.profiling_data["FLOPS"] = cost_analysis
70
63
  self.profiling_data["generated_code"] = memory_analysis[0]
71
64
  self.profiling_data["argument_size"] = memory_analysis[1]
72
65
  self.profiling_data["output_size"] = memory_analysis[2]
@@ -145,7 +138,6 @@ class Timer:
145
138
  std_time = np.std(times_array)
146
139
  last_time = times_array[-1]
147
140
 
148
- flops = self.profiling_data["FLOPS"]
149
141
  generated_code = self.profiling_data["generated_code"]
150
142
  argument_size = self.profiling_data["argument_size"]
151
143
  output_size = self.profiling_data["output_size"]
@@ -154,7 +146,7 @@ class Timer:
154
146
  csv_line = (
155
147
  f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
156
148
  f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
157
- f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
149
+ f"{generated_code},{argument_size},{output_size},{temp_size}\n"
158
150
  )
159
151
 
160
152
  with open(csv_filename, 'a') as f:
@@ -123,7 +123,7 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
123
123
  # Sort all and keep fastest
124
124
  sorted_dfs = []
125
125
  for _, group in df_decomp:
126
- group.sort_values(by=[y_col], inplace=True, ascending=False)
126
+ group.sort_values(by=[y_col], inplace=True, ascending=True)
127
127
  sorted_dfs.append(group.iloc[0])
128
128
  sorted_df = pd.DataFrame(sorted_dfs)
129
129
  label_params.update({
@@ -336,12 +336,11 @@ def clean_up_csv(
336
336
  # Filter data sizes
337
337
  if data_sizes:
338
338
  df = df[df['x'].isin(data_sizes)]
339
-
339
+
340
340
  # Filter pdims
341
341
  if pdims:
342
342
  px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
343
343
  df = df[(df['px'].isin(px_list)) & (df['py'].isin(py_list))]
344
-
345
344
  # convert memory units columns to remquested memory_units
346
345
  match memory_units:
347
346
  case 'KB':
@@ -385,6 +384,7 @@ def clean_up_csv(
385
384
  df.drop(columns=['px', 'py'], inplace=True)
386
385
  if not 'plot_all' in pdims_strategy:
387
386
  df = df[df['decomp'].isin(pdims_strategy)]
387
+
388
388
  # check available gpus in dataset
389
389
  available_gpu_counts.update(df['gpus'].unique())
390
390
  available_data_sizes.update(df['x'].unique())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE