jax-hpc-profiler 0.2.5__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/pyproject.toml +1 -1
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler/create_argparse.py +22 -19
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler/main.py +13 -5
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler/plotting.py +3 -2
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler/timer.py +3 -10
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler/utils.py +1 -1
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/README.md +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler/__init__.py +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -67,7 +67,6 @@ def create_argparser():
|
|
|
67
67
|
plot_parser.add_argument('-fn',
|
|
68
68
|
'--function_name',
|
|
69
69
|
nargs='+',
|
|
70
|
-
default=['FFT'],
|
|
71
70
|
help='Function names to filter')
|
|
72
71
|
|
|
73
72
|
# Time or memory related arguments
|
|
@@ -135,24 +134,28 @@ def create_argparser():
|
|
|
135
134
|
'Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%)',
|
|
136
135
|
default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
|
|
137
136
|
|
|
138
|
-
|
|
137
|
+
subparsers.add_parser('label_help',help='Label customization help')
|
|
139
138
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
139
|
+
args = parser.parse_args()
|
|
140
|
+
|
|
141
|
+
# if command was plot, then check if pdim_strategy is validat
|
|
142
|
+
if args.command == 'plot':
|
|
143
|
+
if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
144
|
+
print(
|
|
145
|
+
"Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
|
|
146
|
+
)
|
|
147
|
+
args.pdim_strategy = ['plot_all']
|
|
148
|
+
|
|
149
|
+
if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
150
|
+
print(
|
|
151
|
+
"Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
|
|
152
|
+
)
|
|
153
|
+
args.pdim_strategy = ['plot_fastest']
|
|
154
|
+
if args.plot_times is not None:
|
|
155
|
+
args.plot_columns = args.plot_times
|
|
156
|
+
elif args.plot_memory is not None:
|
|
157
|
+
args.plot_columns = args.plot_memory
|
|
158
|
+
else:
|
|
159
|
+
raise ValueError('Either plot_times or plot_memory should be provided')
|
|
157
160
|
|
|
158
161
|
return args
|
|
@@ -11,6 +11,15 @@ def main():
|
|
|
11
11
|
if args.command == 'concat':
|
|
12
12
|
input_dir, output_dir = args.input, args.output
|
|
13
13
|
concatenate_csvs(input_dir, output_dir)
|
|
14
|
+
elif args.command == 'label_help':
|
|
15
|
+
print(f"Customize the label text for the plot. using these commands.")
|
|
16
|
+
print(' -- %m% or %methodname%: method name')
|
|
17
|
+
print(' -- %f% or %function%: function name')
|
|
18
|
+
print(' -- %pn% or %plot_name%: plot name')
|
|
19
|
+
print(' -- %pr% or %precision%: precision')
|
|
20
|
+
print(' -- %b% or %backend%: backend')
|
|
21
|
+
print(' -- %p% or %pdims%: pdims')
|
|
22
|
+
print(' -- %n% or %node%: node')
|
|
14
23
|
elif args.command == 'plot':
|
|
15
24
|
dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
|
|
16
25
|
args.csv_files, args.precision, args.function_name, args.gpus,
|
|
@@ -23,11 +32,10 @@ def main():
|
|
|
23
32
|
f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
|
|
24
33
|
)
|
|
25
34
|
# filter back the requested data sizes and gpus
|
|
26
|
-
|
|
27
|
-
args.
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
]
|
|
35
|
+
|
|
36
|
+
args.gpus = available_gpu_counts if args.gpus is None else [gpu for gpu in args.gpus if gpu in available_gpu_counts]
|
|
37
|
+
args.data_size = available_data_sizes if args.data_size is None else [data_size for data_size in args.data_size if data_size in available_data_sizes]
|
|
38
|
+
|
|
31
39
|
if len(args.gpus) == 0:
|
|
32
40
|
print(f"No dataframes found for the given GPUs. Exiting...")
|
|
33
41
|
sys.exit(1)
|
|
@@ -10,7 +10,7 @@ from matplotlib.patches import FancyBboxPatch
|
|
|
10
10
|
from .utils import inspect_df, plot_with_pdims_strategy
|
|
11
11
|
|
|
12
12
|
np.seterr(divide='ignore')
|
|
13
|
-
plt.rcParams.update({'font.size':
|
|
13
|
+
plt.rcParams.update({'font.size': 15})
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def configure_axes(ax: Axes,
|
|
@@ -38,7 +38,7 @@ def configure_axes(ax: Axes,
|
|
|
38
38
|
f2 = lambda x: np.log2(x)
|
|
39
39
|
g2 = lambda x: 2**x
|
|
40
40
|
ax.set_xlim([min(x_values), max(x_values)])
|
|
41
|
-
y_min, y_max = min(y_values) * 0.
|
|
41
|
+
y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
|
|
42
42
|
ax.set_title(title)
|
|
43
43
|
ax.set_ylim([y_min, y_max])
|
|
44
44
|
ax.set_xscale('function', functions=(f2, g2))
|
|
@@ -57,6 +57,7 @@ def configure_axes(ax: Axes,
|
|
|
57
57
|
ax.legend(loc='lower center',
|
|
58
58
|
bbox_to_anchor=(0.5, 0.05),
|
|
59
59
|
ncol=4,
|
|
60
|
+
fontsize="x-large",
|
|
60
61
|
prop={'size': 14})
|
|
61
62
|
|
|
62
63
|
|
|
@@ -23,16 +23,11 @@ class Timer:
|
|
|
23
23
|
self.compiled_code = {}
|
|
24
24
|
self.save_jaxpr = save_jaxpr
|
|
25
25
|
|
|
26
|
-
def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
|
|
27
|
-
if cost_analysis is None:
|
|
28
|
-
return None
|
|
29
|
-
return cost_analysis[0]['flops']
|
|
30
|
-
|
|
31
26
|
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
27
|
|
|
33
28
|
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
34
29
|
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
|
-
factor = int(np.log10(memory_analysis) // 3)
|
|
30
|
+
factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
|
|
36
31
|
|
|
37
32
|
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
38
33
|
|
|
@@ -62,11 +57,9 @@ class Timer:
|
|
|
62
57
|
compiled = lowered.compile()
|
|
63
58
|
memory_analysis = self._read_memory_analysis(
|
|
64
59
|
compiled.memory_analysis())
|
|
65
|
-
cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
|
|
66
60
|
|
|
67
61
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
68
62
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
69
|
-
self.profiling_data["FLOPS"] = cost_analysis
|
|
70
63
|
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
71
64
|
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
72
65
|
self.profiling_data["output_size"] = memory_analysis[2]
|
|
@@ -138,13 +131,13 @@ class Timer:
|
|
|
138
131
|
|
|
139
132
|
times_array = self._get_mean_times()
|
|
140
133
|
if jax.process_index() == 0:
|
|
134
|
+
|
|
141
135
|
min_time = np.min(times_array)
|
|
142
136
|
max_time = np.max(times_array)
|
|
143
137
|
mean_time = np.mean(times_array)
|
|
144
138
|
std_time = np.std(times_array)
|
|
145
139
|
last_time = times_array[-1]
|
|
146
140
|
|
|
147
|
-
flops = self.profiling_data["FLOPS"]
|
|
148
141
|
generated_code = self.profiling_data["generated_code"]
|
|
149
142
|
argument_size = self.profiling_data["argument_size"]
|
|
150
143
|
output_size = self.profiling_data["output_size"]
|
|
@@ -153,7 +146,7 @@ class Timer:
|
|
|
153
146
|
csv_line = (
|
|
154
147
|
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
155
148
|
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
156
|
-
f"{generated_code},{argument_size},{output_size},{temp_size}
|
|
149
|
+
f"{generated_code},{argument_size},{output_size},{temp_size}\n"
|
|
157
150
|
)
|
|
158
151
|
|
|
159
152
|
with open(csv_filename, 'a') as f:
|
|
@@ -123,7 +123,7 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
|
|
|
123
123
|
# Sort all and keep fastest
|
|
124
124
|
sorted_dfs = []
|
|
125
125
|
for _, group in df_decomp:
|
|
126
|
-
group.sort_values(by=[y_col], inplace=True, ascending=
|
|
126
|
+
group.sort_values(by=[y_col], inplace=True, ascending=True)
|
|
127
127
|
sorted_dfs.append(group.iloc[0])
|
|
128
128
|
sorted_df = pd.DataFrame(sorted_dfs)
|
|
129
129
|
label_params.update({
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/entry_points.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/requires.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.7}/src/jax_hpc_profiler.egg-info/top_level.txt
RENAMED
|
File without changes
|