jax-hpc-profiler 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,7 @@ import pandas as pd
7
7
  from matplotlib.axes import Axes
8
8
  from matplotlib.patches import FancyBboxPatch
9
9
 
10
- from .utils import inspect_df, plot_with_pdims_strategy
10
+ from .utils import inspect_df, plot_with_pdims_strategy, inspect_data
11
11
 
12
12
  np.seterr(divide='ignore')
13
13
  plt.rcParams.update({'font.size': 15})
@@ -118,13 +118,13 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
118
118
  axs = axs.flatten()
119
119
  else:
120
120
  axs = [axs]
121
-
121
+
122
122
  for i, fixed_size in enumerate(fixed_sizes):
123
123
  ax: Axes = axs[i]
124
124
 
125
+ x_values = []
126
+ y_values = []
125
127
  for method, df in dataframes.items():
126
- x_values = []
127
- y_values = []
128
128
 
129
129
  filtered_method_df = df[df[fixed_column] == int(fixed_size)]
130
130
  if filtered_method_df.empty:
@@ -137,6 +137,7 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
137
137
  plot_columns)
138
138
 
139
139
  for backend, precision, function, plot_column in combinations:
140
+
140
141
  filtered_params_df = filtered_method_df[
141
142
  (filtered_method_df['backend'] == backend)
142
143
  & (filtered_method_df['precision'] == precision) &
@@ -149,10 +150,11 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
149
150
 
150
151
  x_values.extend(x_vals)
151
152
  y_values.extend(y_vals)
152
-
153
- plotting_memory = 'time' not in plot_columns[0].lower()
154
- configure_axes(ax, x_values, y_values, f"{title} {fixed_size}", xlabel,
155
- plotting_memory, memory_units)
153
+
154
+ if len(x_values) != 0:
155
+ plotting_memory = 'time' not in plot_columns[0].lower()
156
+ configure_axes(ax, x_values, y_values, f"{title} {fixed_size}", xlabel,
157
+ plotting_memory, memory_units)
156
158
 
157
159
  for i in range(num_subplots, num_rows * num_cols):
158
160
  fig.delaxes(axs[i])
jax_hpc_profiler/utils.py CHANGED
@@ -336,12 +336,11 @@ def clean_up_csv(
336
336
  # Filter data sizes
337
337
  if data_sizes:
338
338
  df = df[df['x'].isin(data_sizes)]
339
-
339
+
340
340
  # Filter pdims
341
341
  if pdims:
342
342
  px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
343
343
  df = df[(df['px'].isin(px_list)) & (df['py'].isin(py_list))]
344
-
345
344
  # convert memory units columns to remquested memory_units
346
345
  match memory_units:
347
346
  case 'KB':
@@ -385,6 +384,7 @@ def clean_up_csv(
385
384
  df.drop(columns=['px', 'py'], inplace=True)
386
385
  if not 'plot_all' in pdims_strategy:
387
386
  df = df[df['decomp'].isin(pdims_strategy)]
387
+
388
388
  # check available gpus in dataset
389
389
  available_gpu_counts.update(df['gpus'].unique())
390
390
  available_data_sizes.update(df['x'].unique())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.7
3
+ Version: 0.2.8
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,12 @@
1
+ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
+ jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
3
+ jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
4
+ jax_hpc_profiler/plotting.py,sha256=jnWpyfZOB5w5L4BNIU5euKbUt__wfGPyCjUITrEwScM,8431
5
+ jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
6
+ jax_hpc_profiler/utils.py,sha256=tUXnNHwQSSCqA6XxLd4MoV2gyFIC7ncB-uvVc2INhms,14140
7
+ jax_hpc_profiler-0.2.8.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.8.dist-info/METADATA,sha256=NDOKx7RzKr4Z4mJnY457-4uxaP2TDhwWYH5rZRxfukQ,49250
9
+ jax_hpc_profiler-0.2.8.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
10
+ jax_hpc_profiler-0.2.8.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.8.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (73.0.1)
2
+ Generator: setuptools (74.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,12 +0,0 @@
1
- jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
- jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
3
- jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
4
- jax_hpc_profiler/plotting.py,sha256=cwHznCZ2pF2J7AtyUOB3pASnahKBLRWHAPGXmGDvWas,8360
5
- jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
6
- jax_hpc_profiler/utils.py,sha256=okWQUJHblUKkYnw7j7wJ75PSbhVItXKkTMKjj0BmgR0,14132
7
- jax_hpc_profiler-0.2.7.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.7.dist-info/METADATA,sha256=bQkpy5Kr8ybEM7GU7qR0FEnDV7xsLbrq98GRDfgDTQU,49250
9
- jax_hpc_profiler-0.2.7.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
10
- jax_hpc_profiler-0.2.7.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.7.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.7.dist-info/RECORD,,