jax-hpc-profiler 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/create_argparse.py +2 -1
- jax_hpc_profiler/main.py +13 -5
- jax_hpc_profiler/plotting.py +3 -2
- jax_hpc_profiler/timer.py +2 -10
- jax_hpc_profiler/utils.py +1 -1
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.7.dist-info}/METADATA +1 -1
- jax_hpc_profiler-0.2.7.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.7.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.6.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.7.dist-info}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.7.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.6.dist-info → jax_hpc_profiler-0.2.7.dist-info}/top_level.txt +0 -0
|
@@ -67,7 +67,6 @@ def create_argparser():
|
|
|
67
67
|
plot_parser.add_argument('-fn',
|
|
68
68
|
'--function_name',
|
|
69
69
|
nargs='+',
|
|
70
|
-
default=['FFT'],
|
|
71
70
|
help='Function names to filter')
|
|
72
71
|
|
|
73
72
|
# Time or memory related arguments
|
|
@@ -135,6 +134,8 @@ def create_argparser():
|
|
|
135
134
|
'Custom label for the plot. You can use placeholders: %decomposition% (or %p%), %precision% (or %pr%), %plot_name% (or %pn%), %backend% (or %b%), %node% (or %n%), %methodname% (or %m%)',
|
|
136
135
|
default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
|
|
137
136
|
|
|
137
|
+
subparsers.add_parser('label_help',help='Label customization help')
|
|
138
|
+
|
|
138
139
|
args = parser.parse_args()
|
|
139
140
|
|
|
140
141
|
# if command was plot, then check if pdim_strategy is validat
|
jax_hpc_profiler/main.py
CHANGED
|
@@ -11,6 +11,15 @@ def main():
|
|
|
11
11
|
if args.command == 'concat':
|
|
12
12
|
input_dir, output_dir = args.input, args.output
|
|
13
13
|
concatenate_csvs(input_dir, output_dir)
|
|
14
|
+
elif args.command == 'label_help':
|
|
15
|
+
print(f"Customize the label text for the plot. using these commands.")
|
|
16
|
+
print(' -- %m% or %methodname%: method name')
|
|
17
|
+
print(' -- %f% or %function%: function name')
|
|
18
|
+
print(' -- %pn% or %plot_name%: plot name')
|
|
19
|
+
print(' -- %pr% or %precision%: precision')
|
|
20
|
+
print(' -- %b% or %backend%: backend')
|
|
21
|
+
print(' -- %p% or %pdims%: pdims')
|
|
22
|
+
print(' -- %n% or %node%: node')
|
|
14
23
|
elif args.command == 'plot':
|
|
15
24
|
dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
|
|
16
25
|
args.csv_files, args.precision, args.function_name, args.gpus,
|
|
@@ -23,11 +32,10 @@ def main():
|
|
|
23
32
|
f"requested GPUS: {args.gpus} available GPUS: {available_gpu_counts}"
|
|
24
33
|
)
|
|
25
34
|
# filter back the requested data sizes and gpus
|
|
26
|
-
|
|
27
|
-
args.
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
]
|
|
35
|
+
|
|
36
|
+
args.gpus = available_gpu_counts if args.gpus is None else [gpu for gpu in args.gpus if gpu in available_gpu_counts]
|
|
37
|
+
args.data_size = available_data_sizes if args.data_size is None else [data_size for data_size in args.data_size if data_size in available_data_sizes]
|
|
38
|
+
|
|
31
39
|
if len(args.gpus) == 0:
|
|
32
40
|
print(f"No dataframes found for the given GPUs. Exiting...")
|
|
33
41
|
sys.exit(1)
|
jax_hpc_profiler/plotting.py
CHANGED
|
@@ -10,7 +10,7 @@ from matplotlib.patches import FancyBboxPatch
|
|
|
10
10
|
from .utils import inspect_df, plot_with_pdims_strategy
|
|
11
11
|
|
|
12
12
|
np.seterr(divide='ignore')
|
|
13
|
-
plt.rcParams.update({'font.size':
|
|
13
|
+
plt.rcParams.update({'font.size': 15})
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def configure_axes(ax: Axes,
|
|
@@ -38,7 +38,7 @@ def configure_axes(ax: Axes,
|
|
|
38
38
|
f2 = lambda x: np.log2(x)
|
|
39
39
|
g2 = lambda x: 2**x
|
|
40
40
|
ax.set_xlim([min(x_values), max(x_values)])
|
|
41
|
-
y_min, y_max = min(y_values) * 0.
|
|
41
|
+
y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
|
|
42
42
|
ax.set_title(title)
|
|
43
43
|
ax.set_ylim([y_min, y_max])
|
|
44
44
|
ax.set_xscale('function', functions=(f2, g2))
|
|
@@ -57,6 +57,7 @@ def configure_axes(ax: Axes,
|
|
|
57
57
|
ax.legend(loc='lower center',
|
|
58
58
|
bbox_to_anchor=(0.5, 0.05),
|
|
59
59
|
ncol=4,
|
|
60
|
+
fontsize="x-large",
|
|
60
61
|
prop={'size': 14})
|
|
61
62
|
|
|
62
63
|
|
jax_hpc_profiler/timer.py
CHANGED
|
@@ -23,16 +23,11 @@ class Timer:
|
|
|
23
23
|
self.compiled_code = {}
|
|
24
24
|
self.save_jaxpr = save_jaxpr
|
|
25
25
|
|
|
26
|
-
def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
|
|
27
|
-
if cost_analysis is None:
|
|
28
|
-
return None
|
|
29
|
-
return cost_analysis[0]['flops']
|
|
30
|
-
|
|
31
26
|
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
27
|
|
|
33
28
|
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
34
29
|
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
|
-
factor = int(np.log10(memory_analysis) // 3)
|
|
30
|
+
factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
|
|
36
31
|
|
|
37
32
|
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
38
33
|
|
|
@@ -62,11 +57,9 @@ class Timer:
|
|
|
62
57
|
compiled = lowered.compile()
|
|
63
58
|
memory_analysis = self._read_memory_analysis(
|
|
64
59
|
compiled.memory_analysis())
|
|
65
|
-
cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
|
|
66
60
|
|
|
67
61
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
68
62
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
69
|
-
self.profiling_data["FLOPS"] = cost_analysis
|
|
70
63
|
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
71
64
|
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
72
65
|
self.profiling_data["output_size"] = memory_analysis[2]
|
|
@@ -145,7 +138,6 @@ class Timer:
|
|
|
145
138
|
std_time = np.std(times_array)
|
|
146
139
|
last_time = times_array[-1]
|
|
147
140
|
|
|
148
|
-
flops = self.profiling_data["FLOPS"]
|
|
149
141
|
generated_code = self.profiling_data["generated_code"]
|
|
150
142
|
argument_size = self.profiling_data["argument_size"]
|
|
151
143
|
output_size = self.profiling_data["output_size"]
|
|
@@ -154,7 +146,7 @@ class Timer:
|
|
|
154
146
|
csv_line = (
|
|
155
147
|
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
156
148
|
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
157
|
-
f"{generated_code},{argument_size},{output_size},{temp_size}
|
|
149
|
+
f"{generated_code},{argument_size},{output_size},{temp_size}\n"
|
|
158
150
|
)
|
|
159
151
|
|
|
160
152
|
with open(csv_filename, 'a') as f:
|
jax_hpc_profiler/utils.py
CHANGED
|
@@ -123,7 +123,7 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
|
|
|
123
123
|
# Sort all and keep fastest
|
|
124
124
|
sorted_dfs = []
|
|
125
125
|
for _, group in df_decomp:
|
|
126
|
-
group.sort_values(by=[y_col], inplace=True, ascending=
|
|
126
|
+
group.sort_values(by=[y_col], inplace=True, ascending=True)
|
|
127
127
|
sorted_dfs.append(group.iloc[0])
|
|
128
128
|
sorted_df = pd.DataFrame(sorted_dfs)
|
|
129
129
|
label_params.update({
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
+
jax_hpc_profiler/create_argparse.py,sha256=m9_lg9HHxq2JDMITiHXQW1Ximua0ClwsEq1Zd9Y0hvo,6511
|
|
3
|
+
jax_hpc_profiler/main.py,sha256=VJKvVc4m2XGJI2yp9ZF9tmmBmnTDpZ7-6LGo8ZIrWLc,2906
|
|
4
|
+
jax_hpc_profiler/plotting.py,sha256=cwHznCZ2pF2J7AtyUOB3pASnahKBLRWHAPGXmGDvWas,8360
|
|
5
|
+
jax_hpc_profiler/timer.py,sha256=qPp3NcCJlMM-Cmw2mEWn63BlvPqmj_k7E8P9m0-Fy7k,8294
|
|
6
|
+
jax_hpc_profiler/utils.py,sha256=okWQUJHblUKkYnw7j7wJ75PSbhVItXKkTMKjj0BmgR0,14132
|
|
7
|
+
jax_hpc_profiler-0.2.7.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
+
jax_hpc_profiler-0.2.7.dist-info/METADATA,sha256=bQkpy5Kr8ybEM7GU7qR0FEnDV7xsLbrq98GRDfgDTQU,49250
|
|
9
|
+
jax_hpc_profiler-0.2.7.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
|
|
10
|
+
jax_hpc_profiler-0.2.7.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
+
jax_hpc_profiler-0.2.7.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
+
jax_hpc_profiler-0.2.7.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
-
jax_hpc_profiler/create_argparse.py,sha256=dEicamRYqJ6GGdgcph2bwAbmdxPkS4tS12xZ4c0X_Pk,6484
|
|
3
|
-
jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
|
|
4
|
-
jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
|
|
5
|
-
jax_hpc_profiler/timer.py,sha256=j6oH5IZz12VJik2cE7EQ3a9tAW9C8xl7D2QLW8Bkz3s,8617
|
|
6
|
-
jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
|
|
7
|
-
jax_hpc_profiler-0.2.6.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
-
jax_hpc_profiler-0.2.6.dist-info/METADATA,sha256=AgLXyb89gdxgeyDv02_P5oqIvoffdTh1mc3zFPUVBAU,49250
|
|
9
|
-
jax_hpc_profiler-0.2.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
10
|
-
jax_hpc_profiler-0.2.6.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
-
jax_hpc_profiler-0.2.6.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
-
jax_hpc_profiler-0.2.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|