jax-hpc-profiler 0.2.7__tar.gz → 0.2.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/pyproject.toml +1 -1
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/plotting.py +10 -8
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/utils.py +2 -2
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/README.md +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/__init__.py +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/create_argparse.py +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/main.py +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler/timer.py +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -7,7 +7,7 @@ import pandas as pd
|
|
|
7
7
|
from matplotlib.axes import Axes
|
|
8
8
|
from matplotlib.patches import FancyBboxPatch
|
|
9
9
|
|
|
10
|
-
from .utils import inspect_df, plot_with_pdims_strategy
|
|
10
|
+
from .utils import inspect_df, plot_with_pdims_strategy, inspect_data
|
|
11
11
|
|
|
12
12
|
np.seterr(divide='ignore')
|
|
13
13
|
plt.rcParams.update({'font.size': 15})
|
|
@@ -118,13 +118,13 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
|
|
|
118
118
|
axs = axs.flatten()
|
|
119
119
|
else:
|
|
120
120
|
axs = [axs]
|
|
121
|
-
|
|
121
|
+
|
|
122
122
|
for i, fixed_size in enumerate(fixed_sizes):
|
|
123
123
|
ax: Axes = axs[i]
|
|
124
124
|
|
|
125
|
+
x_values = []
|
|
126
|
+
y_values = []
|
|
125
127
|
for method, df in dataframes.items():
|
|
126
|
-
x_values = []
|
|
127
|
-
y_values = []
|
|
128
128
|
|
|
129
129
|
filtered_method_df = df[df[fixed_column] == int(fixed_size)]
|
|
130
130
|
if filtered_method_df.empty:
|
|
@@ -137,6 +137,7 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
|
|
|
137
137
|
plot_columns)
|
|
138
138
|
|
|
139
139
|
for backend, precision, function, plot_column in combinations:
|
|
140
|
+
|
|
140
141
|
filtered_params_df = filtered_method_df[
|
|
141
142
|
(filtered_method_df['backend'] == backend)
|
|
142
143
|
& (filtered_method_df['precision'] == precision) &
|
|
@@ -149,10 +150,11 @@ def plot_scaling(dataframes: Dict[str, pd.DataFrame],
|
|
|
149
150
|
|
|
150
151
|
x_values.extend(x_vals)
|
|
151
152
|
y_values.extend(y_vals)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
153
|
+
|
|
154
|
+
if len(x_values) != 0:
|
|
155
|
+
plotting_memory = 'time' not in plot_columns[0].lower()
|
|
156
|
+
configure_axes(ax, x_values, y_values, f"{title} {fixed_size}", xlabel,
|
|
157
|
+
plotting_memory, memory_units)
|
|
156
158
|
|
|
157
159
|
for i in range(num_subplots, num_rows * num_cols):
|
|
158
160
|
fig.delaxes(axs[i])
|
|
@@ -336,12 +336,11 @@ def clean_up_csv(
|
|
|
336
336
|
# Filter data sizes
|
|
337
337
|
if data_sizes:
|
|
338
338
|
df = df[df['x'].isin(data_sizes)]
|
|
339
|
-
|
|
339
|
+
|
|
340
340
|
# Filter pdims
|
|
341
341
|
if pdims:
|
|
342
342
|
px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
|
|
343
343
|
df = df[(df['px'].isin(px_list)) & (df['py'].isin(py_list))]
|
|
344
|
-
|
|
345
344
|
# convert memory units columns to requested memory_units
|
|
346
345
|
match memory_units:
|
|
347
346
|
case 'KB':
|
|
@@ -385,6 +384,7 @@ def clean_up_csv(
|
|
|
385
384
|
df.drop(columns=['px', 'py'], inplace=True)
|
|
386
385
|
if not 'plot_all' in pdims_strategy:
|
|
387
386
|
df = df[df['decomp'].isin(pdims_strategy)]
|
|
387
|
+
|
|
388
388
|
# check available gpus in dataset
|
|
389
389
|
available_gpu_counts.update(df['gpus'].unique())
|
|
390
390
|
available_data_sizes.update(df['x'].unique())
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/entry_points.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/requires.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.7 → jax_hpc_profiler-0.2.8}/src/jax_hpc_profiler.egg-info/top_level.txt
RENAMED
|
File without changes
|