jax-hpc-profiler 0.2.9__tar.gz → 0.2.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/PKG-INFO +2 -2
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/pyproject.toml +1 -1
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/plotting.py +14 -9
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/timer.py +17 -1
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/utils.py +3 -2
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/PKG-INFO +2 -2
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/README.md +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/__init__.py +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/create_argparse.py +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler/main.py +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -77,9 +77,9 @@ def plot_scaling(
|
|
|
77
77
|
output: Optional[str] = None,
|
|
78
78
|
dark_bg: bool = False,
|
|
79
79
|
print_decompositions: bool = False,
|
|
80
|
-
backends: List[str] =
|
|
81
|
-
precisions: List[str] =
|
|
82
|
-
functions: List[str]
|
|
80
|
+
backends: Optional[List[str]] = None,
|
|
81
|
+
precisions: Optional[List[str]] = None,
|
|
82
|
+
functions: Optional[List[str]] = None,
|
|
83
83
|
plot_columns: List[str] = ["mean_time"],
|
|
84
84
|
memory_units: str = "bytes",
|
|
85
85
|
label_text: str = "plot",
|
|
@@ -141,6 +141,11 @@ def plot_scaling(
|
|
|
141
141
|
by=[size_column])
|
|
142
142
|
functions = (pd.unique(filtered_method_df["function"])
|
|
143
143
|
if functions is None else functions)
|
|
144
|
+
precisions = (pd.unique(filtered_method_df["precision"])
|
|
145
|
+
if precisions is None else precisions)
|
|
146
|
+
backends = (pd.unique(filtered_method_df["backend"])
|
|
147
|
+
if backends is None else backends)
|
|
148
|
+
|
|
144
149
|
combinations = product(backends, precisions, functions,
|
|
145
150
|
plot_columns)
|
|
146
151
|
|
|
@@ -199,12 +204,12 @@ def plot_strong_scaling(
|
|
|
199
204
|
csv_files: List[str],
|
|
200
205
|
fixed_gpu_size: Optional[List[int]] = None,
|
|
201
206
|
fixed_data_size: Optional[List[int]] = None,
|
|
202
|
-
functions: List[str]
|
|
203
|
-
precisions: List[str] =
|
|
207
|
+
functions: Optional[List[str]] = None,
|
|
208
|
+
precisions: Optional[List[str]] = None,
|
|
204
209
|
pdims: Optional[List[str]] = None,
|
|
205
210
|
pdims_strategy: List[str] = ["plot_fastest"],
|
|
206
211
|
print_decompositions: bool = False,
|
|
207
|
-
backends: List[str] =
|
|
212
|
+
backends: Optional[List[str]] = None,
|
|
208
213
|
plot_columns: List[str] = ["mean_time"],
|
|
209
214
|
memory_units: str = "bytes",
|
|
210
215
|
label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
@@ -256,12 +261,12 @@ def plot_weak_scaling(
|
|
|
256
261
|
csv_files: List[str],
|
|
257
262
|
fixed_gpu_size: Optional[List[int]] = None,
|
|
258
263
|
fixed_data_size: Optional[List[int]] = None,
|
|
259
|
-
functions: List[str]
|
|
260
|
-
precisions: List[str] =
|
|
264
|
+
functions: Optional[List[str]] = None,
|
|
265
|
+
precisions: Optional[List[str]] = None,
|
|
261
266
|
pdims: Optional[List[str]] = None,
|
|
262
267
|
pdims_strategy: List[str] = ["plot_fastest"],
|
|
263
268
|
print_decompositions: bool = False,
|
|
264
|
-
backends: List[str] =
|
|
269
|
+
backends: Optional[List[str]] = None,
|
|
265
270
|
plot_columns: List[str] = ["mean_time"],
|
|
266
271
|
memory_units: str = "bytes",
|
|
267
272
|
label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import time
|
|
3
3
|
from functools import partial
|
|
4
|
-
from typing import Any, Callable, List, Tuple
|
|
4
|
+
from typing import Any, Callable, List, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
@@ -133,8 +133,12 @@ class Timer:
|
|
|
133
133
|
backend: str = "NCCL",
|
|
134
134
|
nodes: int = 1,
|
|
135
135
|
md_filename: str | None = None,
|
|
136
|
+
npz_data: Optional[dict] = None,
|
|
136
137
|
extra_info: dict = {},
|
|
137
138
|
):
|
|
139
|
+
if self.jit_time == 0.0 and len(self.times) == 0:
|
|
140
|
+
print(f"No profiling data to report for {function}")
|
|
141
|
+
return
|
|
138
142
|
|
|
139
143
|
if md_filename is None:
|
|
140
144
|
dirname, filename = (
|
|
@@ -147,6 +151,18 @@ class Timer:
|
|
|
147
151
|
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
|
|
148
152
|
)
|
|
149
153
|
|
|
154
|
+
if npz_data is not None:
|
|
155
|
+
dirname, filename = (
|
|
156
|
+
os.path.dirname(csv_filename),
|
|
157
|
+
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
158
|
+
)
|
|
159
|
+
report_folder = filename if dirname == "" else f"{dirname}/{filename}"
|
|
160
|
+
os.makedirs(report_folder, exist_ok=True)
|
|
161
|
+
npz_filename = (
|
|
162
|
+
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
|
|
163
|
+
)
|
|
164
|
+
np.savez(npz_filename, **npz_data)
|
|
165
|
+
|
|
150
166
|
y = x if y is None else y
|
|
151
167
|
z = x if z is None else z
|
|
152
168
|
|
|
@@ -248,7 +248,7 @@ def clean_up_csv(
|
|
|
248
248
|
data_sizes: Optional[List[int]] = None,
|
|
249
249
|
pdims: Optional[List[str]] = None,
|
|
250
250
|
pdims_strategy: List[str] = ['plot_fastest'],
|
|
251
|
-
backends: List[str] =
|
|
251
|
+
backends: Optional[List[str]] = None,
|
|
252
252
|
memory_units: str = 'KB',
|
|
253
253
|
) -> Tuple[Dict[str, pd.DataFrame], List[int], List[int]]:
|
|
254
254
|
"""
|
|
@@ -331,7 +331,8 @@ def clean_up_csv(
|
|
|
331
331
|
if function_names:
|
|
332
332
|
df = df[df['function'].isin(function_names)]
|
|
333
333
|
# Filter backends
|
|
334
|
-
|
|
334
|
+
if backends:
|
|
335
|
+
df = df[df['backend'].isin(backends)]
|
|
335
336
|
|
|
336
337
|
# Filter data sizes
|
|
337
338
|
if data_sizes:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/SOURCES.txt
RENAMED
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/entry_points.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/requires.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.9 → jax_hpc_profiler-0.2.10}/src/jax_hpc_profiler.egg-info/top_level.txt
RENAMED
|
File without changes
|