jax-hpc-profiler 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -67,7 +67,6 @@ def create_argparser():
67
67
  plot_parser.add_argument('-fn',
68
68
  '--function_name',
69
69
  nargs='+',
70
- default=['FFT'],
71
70
  help='Function names to filter')
72
71
 
73
72
  # Time or memory related arguments
@@ -135,6 +134,8 @@ def create_argparser():
135
134
  'Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%)',
136
135
  default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
137
136
 
137
+ subparsers.add_parser('label_help',help='Label customization help')
138
+
138
139
  args = parser.parse_args()
139
140
 
140
141
  # if command was plot, then check if pdim_strategy is validat
jax_hpc_profiler/main.py CHANGED
@@ -11,6 +11,15 @@ def main():
11
11
  if args.command == 'concat':
12
12
  input_dir, output_dir = args.input, args.output
13
13
  concatenate_csvs(input_dir, output_dir)
14
+ elif args.command == 'label_help':
15
+ print(f"Customize the label text for the plot. using these commands.")
16
+ print(' -- %m% or %methodname%: method name')
17
+ print(' -- %f% or %function%: function name')
18
+ print(' -- %pn% or %plot_name%: plot name')
19
+ print(' -- %pr% or %precision%: precision')
20
+ print(' -- %b% or %backend%: backend')
21
+ print(' -- %p% or %pdims%: pdims')
22
+ print(' -- %n% or %node%: node')
14
23
  elif args.command == 'plot':
15
24
  dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
16
25
  args.csv_files, args.precision, args.function_name, args.gpus,
@@ -23,11 +32,10 @@ def main():
23
32
  f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
24
33
  )
25
34
  # filter back the requested data sizes and gpus
26
- args.gpus = [gpu for gpu in args.gpus if gpu in available_gpu_counts]
27
- args.data_size = [
28
- data_size for data_size in args.data_size
29
- if data_size in available_data_sizes
30
- ]
35
+
36
+ args.gpus = available_gpu_counts if args.gpus is None else [gpu for gpu in args.gpus if gpu in available_gpu_counts]
37
+ args.data_size = available_data_sizes if args.data_size is None else [data_size for data_size in args.data_size if data_size in available_data_sizes]
38
+
31
39
  if len(args.gpus) == 0:
32
40
  print(f"No dataframes found for the given GPUs. Exiting...")
33
41
  sys.exit(1)
@@ -10,7 +10,7 @@ from matplotlib.patches import FancyBboxPatch
10
10
  from .utils import inspect_df, plot_with_pdims_strategy
11
11
 
12
12
  np.seterr(divide='ignore')
13
- plt.rcParams.update({'font.size': 10})
13
+ plt.rcParams.update({'font.size': 15})
14
14
 
15
15
 
16
16
  def configure_axes(ax: Axes,
@@ -38,7 +38,7 @@ def configure_axes(ax: Axes,
38
38
  f2 = lambda x: np.log2(x)
39
39
  g2 = lambda x: 2**x
40
40
  ax.set_xlim([min(x_values), max(x_values)])
41
- y_min, y_max = min(y_values) * 0.9, max(y_values) * 1.1
41
+ y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
42
42
  ax.set_title(title)
43
43
  ax.set_ylim([y_min, y_max])
44
44
  ax.set_xscale('function', functions=(f2, g2))
@@ -57,6 +57,7 @@ def configure_axes(ax: Axes,
57
57
  ax.legend(loc='lower center',
58
58
  bbox_to_anchor=(0.5, 0.05),
59
59
  ncol=4,
60
+ fontsize="x-large",
60
61
  prop={'size': 14})
61
62
 
62
63
 
jax_hpc_profiler/timer.py CHANGED
@@ -23,16 +23,11 @@ class Timer:
23
23
  self.compiled_code = {}
24
24
  self.save_jaxpr = save_jaxpr
25
25
 
26
- def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
27
- if cost_analysis is None:
28
- return None
29
- return cost_analysis[0]['flops']
30
-
31
26
  def _normalize_memory_units(self, memory_analysis) -> str:
32
27
 
33
28
  sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
34
29
  factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
- factor = int(np.log10(memory_analysis) // 3)
30
+ factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
36
31
 
37
32
  return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
33
 
@@ -62,11 +57,9 @@ class Timer:
62
57
  compiled = lowered.compile()
63
58
  memory_analysis = self._read_memory_analysis(
64
59
  compiled.memory_analysis())
65
- cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
66
60
 
67
61
  self.compiled_code["LOWERED"] = lowered.as_text()
68
62
  self.compiled_code["COMPILED"] = compiled.as_text()
69
- self.profiling_data["FLOPS"] = cost_analysis
70
63
  self.profiling_data["generated_code"] = memory_analysis[0]
71
64
  self.profiling_data["argument_size"] = memory_analysis[1]
72
65
  self.profiling_data["output_size"] = memory_analysis[2]
@@ -145,7 +138,6 @@ class Timer:
145
138
  std_time = np.std(times_array)
146
139
  last_time = times_array[-1]
147
140
 
148
- flops = self.profiling_data["FLOPS"]
149
141
  generated_code = self.profiling_data["generated_code"]
150
142
  argument_size = self.profiling_data["argument_size"]
151
143
  output_size = self.profiling_data["output_size"]
@@ -154,7 +146,7 @@ class Timer:
154
146
  csv_line = (
155
147
  f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
156
148
  f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
157
- f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
149
+ f"{generated_code},{argument_size},{output_size},{temp_size}\n"
158
150
  )
159
151
 
160
152
  with open(csv_filename, 'a') as f:
jax_hpc_profiler/utils.py CHANGED
@@ -123,7 +123,7 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
123
123
  # Sort all and keep fastest
124
124
  sorted_dfs = []
125
125
  for _, group in df_decomp:
126
- group.sort_values(by=[y_col], inplace=True, ascending=False)
126
+ group.sort_values(by=[y_col], inplace=True, ascending=True)
127
127
  sorted_dfs.append(group.iloc[0])
128
128
  sorted_df = pd.DataFrame(sorted_dfs)
129
129
  label_params.update({
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,12 @@
1
+ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
+ jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
3
+ jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
4
+ jax_hpc_profiler/plotting.py,sha256=cwHznCZ2pF2J7AtyUOB3pASnahKBLRWHAPGXmGDvWas,8360
5
+ jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
6
+ jax_hpc_profiler/utils.py,sha256=okWQUJHblUKkYnw7j7wJ75PSbhVItXKkTMKjj0BmgR0,14132
7
+ jax_hpc_profiler-0.2.7.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.7.dist-info/METADATA,sha256=bQkpy5Kr8ybEM7GU7qR0FEnDV7xsLbrq98GRDfgDTQU,49250
9
+ jax_hpc_profiler-0.2.7.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
10
+ jax_hpc_profiler-0.2.7.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.7.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.1.0)
2
+ Generator: setuptools (73.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,12 +0,0 @@
1
- jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
- jax_hpc_profiler/create_argparse.py,sha256=dEicamRYqJ6GGdgcph2bwAbmdxPkS4tS12xZ4c0X_Pk,6484
3
- jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
- jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=j6oH5IZz12VJik2cE7EQ3a9tAW9C8xl7D2QLW8Bkz3s,8617
6
- jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
- jax_hpc_profiler-0.2.6.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.6.dist-info/METADATA,sha256=AgLXyb89gdxgeyDv02_P5oqIvoffdTh1mc3zFPUVBAU,49250
9
- jax_hpc_profiler-0.2.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.6.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.6.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.6.dist-info/RECORD,,