jax-hpc-profiler 0.2.3__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/create_argparse.py +20 -18
- jax_hpc_profiler/timer.py +95 -99
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.6.dist-info}/METADATA +1 -1
- jax_hpc_profiler-0.2.6.dist-info/RECORD +12 -0
- jax_hpc_profiler-0.2.3.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.6.dist-info}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.6.dist-info}/WHEEL +0 -0
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.6.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.6.dist-info}/top_level.txt +0 -0
|
@@ -136,23 +136,25 @@ def create_argparser():
|
|
|
136
136
|
default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
|
|
137
137
|
|
|
138
138
|
args = parser.parse_args()
|
|
139
|
-
|
|
140
|
-
if
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
139
|
+
|
|
140
|
+
# if command was plot, then check if pdim_strategy is validat
|
|
141
|
+
if args.command == 'plot':
|
|
142
|
+
if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
143
|
+
print(
|
|
144
|
+
"Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
|
|
145
|
+
)
|
|
146
|
+
args.pdim_strategy = ['plot_all']
|
|
147
|
+
|
|
148
|
+
if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
|
|
149
|
+
print(
|
|
150
|
+
"Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
|
|
151
|
+
)
|
|
152
|
+
args.pdim_strategy = ['plot_fastest']
|
|
153
|
+
if args.plot_times is not None:
|
|
154
|
+
args.plot_columns = args.plot_times
|
|
155
|
+
elif args.plot_memory is not None:
|
|
156
|
+
args.plot_columns = args.plot_memory
|
|
157
|
+
else:
|
|
158
|
+
raise ValueError('Either plot_times or plot_memory should be provided')
|
|
157
159
|
|
|
158
160
|
return args
|
jax_hpc_profiler/timer.py
CHANGED
|
@@ -28,15 +28,14 @@ class Timer:
|
|
|
28
28
|
return None
|
|
29
29
|
return cost_analysis[0]['flops']
|
|
30
30
|
|
|
31
|
-
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
|
-
|
|
33
|
-
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB'
|
|
34
|
-
factors = [1
|
|
31
|
+
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
|
+
|
|
33
|
+
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
34
|
+
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
35
|
factor = int(np.log10(memory_analysis) // 3)
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
38
38
|
|
|
39
|
-
|
|
40
39
|
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
41
40
|
if memory_analysis is None:
|
|
42
41
|
return None, None, None, None
|
|
@@ -68,12 +67,9 @@ class Timer:
|
|
|
68
67
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
69
68
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
70
69
|
self.profiling_data["FLOPS"] = cost_analysis
|
|
71
|
-
self.profiling_data[
|
|
72
|
-
|
|
73
|
-
self.profiling_data[
|
|
74
|
-
"argument_size"] = memory_analysis[1]
|
|
75
|
-
self.profiling_data[
|
|
76
|
-
"output_size"] = memory_analysis[2]
|
|
70
|
+
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
71
|
+
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
72
|
+
self.profiling_data["output_size"] = memory_analysis[2]
|
|
77
73
|
self.profiling_data["temp_size"] = memory_analysis[3]
|
|
78
74
|
return out
|
|
79
75
|
|
|
@@ -101,7 +97,7 @@ class Timer:
|
|
|
101
97
|
global_times = jax.make_array_from_callback(
|
|
102
98
|
shape=global_shape,
|
|
103
99
|
sharding=sharding,
|
|
104
|
-
data_callback=lambda _: jnp.expand_dims(times_array,axis=0))
|
|
100
|
+
data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
|
|
105
101
|
|
|
106
102
|
@partial(shard_map,
|
|
107
103
|
mesh=mesh,
|
|
@@ -141,90 +137,90 @@ class Timer:
|
|
|
141
137
|
z = x if z is None else z
|
|
142
138
|
|
|
143
139
|
times_array = self._get_mean_times()
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
f.write("\n---\n")
|
|
220
|
-
f.write(f"## Lowered Code\n")
|
|
221
|
-
f.write(f"```hlo\n")
|
|
222
|
-
f.write(self.compiled_code["LOWERED"])
|
|
223
|
-
f.write(f"\n```\n")
|
|
224
|
-
f.write("\n---\n")
|
|
225
|
-
if self.save_jaxpr:
|
|
226
|
-
f.write(f"## JAXPR\n")
|
|
227
|
-
f.write(f"```haskel\n")
|
|
228
|
-
f.write(self.compiled_code["JAXPR"])
|
|
140
|
+
if jax.process_index() == 0:
|
|
141
|
+
|
|
142
|
+
min_time = np.min(times_array)
|
|
143
|
+
max_time = np.max(times_array)
|
|
144
|
+
mean_time = np.mean(times_array)
|
|
145
|
+
std_time = np.std(times_array)
|
|
146
|
+
last_time = times_array[-1]
|
|
147
|
+
|
|
148
|
+
flops = self.profiling_data["FLOPS"]
|
|
149
|
+
generated_code = self.profiling_data["generated_code"]
|
|
150
|
+
argument_size = self.profiling_data["argument_size"]
|
|
151
|
+
output_size = self.profiling_data["output_size"]
|
|
152
|
+
temp_size = self.profiling_data["temp_size"]
|
|
153
|
+
|
|
154
|
+
csv_line = (
|
|
155
|
+
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
156
|
+
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
157
|
+
f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
with open(csv_filename, 'a') as f:
|
|
161
|
+
f.write(csv_line)
|
|
162
|
+
|
|
163
|
+
param_dict = {
|
|
164
|
+
"Function": function,
|
|
165
|
+
"Precision": precision,
|
|
166
|
+
"X": x,
|
|
167
|
+
"Y": y,
|
|
168
|
+
"Z": z,
|
|
169
|
+
"PX": px,
|
|
170
|
+
"PY": py,
|
|
171
|
+
"Backend": backend,
|
|
172
|
+
"Nodes": nodes,
|
|
173
|
+
}
|
|
174
|
+
param_dict.update(extra_info)
|
|
175
|
+
profiling_result = {
|
|
176
|
+
"JIT Time": self.jit_time,
|
|
177
|
+
"Min Time": min_time,
|
|
178
|
+
"Max Time": max_time,
|
|
179
|
+
"Mean Time": mean_time,
|
|
180
|
+
"Std Time": std_time,
|
|
181
|
+
"Last Time": last_time,
|
|
182
|
+
"Generated Code": self._normalize_memory_units(generated_code),
|
|
183
|
+
"Argument Size": self._normalize_memory_units(argument_size),
|
|
184
|
+
"Output Size": self._normalize_memory_units(output_size),
|
|
185
|
+
"Temporary Size": self._normalize_memory_units(temp_size),
|
|
186
|
+
"FLOPS": self.profiling_data["FLOPS"]
|
|
187
|
+
}
|
|
188
|
+
iteration_runs = {}
|
|
189
|
+
for i in range(len(times_array)):
|
|
190
|
+
iteration_runs[f"Run {i}"] = times_array[i]
|
|
191
|
+
|
|
192
|
+
with open(md_filename, 'w') as f:
|
|
193
|
+
f.write(f"# Reporting for {function}\n")
|
|
194
|
+
f.write(f"## Parameters\n")
|
|
195
|
+
f.write(
|
|
196
|
+
tabulate(param_dict.items(),
|
|
197
|
+
headers=["Parameter", "Value"],
|
|
198
|
+
tablefmt='github'))
|
|
199
|
+
f.write("\n---\n")
|
|
200
|
+
f.write(f"## Profiling Data\n")
|
|
201
|
+
f.write(
|
|
202
|
+
tabulate(profiling_result.items(),
|
|
203
|
+
headers=["Parameter", "Value"],
|
|
204
|
+
tablefmt='github'))
|
|
205
|
+
f.write("\n---\n")
|
|
206
|
+
f.write(f"## Iteration Runs\n")
|
|
207
|
+
f.write(
|
|
208
|
+
tabulate(iteration_runs.items(),
|
|
209
|
+
headers=["Iteration", "Time"],
|
|
210
|
+
tablefmt='github'))
|
|
211
|
+
f.write("\n---\n")
|
|
212
|
+
f.write(f"## Compiled Code\n")
|
|
213
|
+
f.write(f"```hlo\n")
|
|
214
|
+
f.write(self.compiled_code["COMPILED"])
|
|
229
215
|
f.write(f"\n```\n")
|
|
230
|
-
|
|
216
|
+
f.write("\n---\n")
|
|
217
|
+
f.write(f"## Lowered Code\n")
|
|
218
|
+
f.write(f"```hlo\n")
|
|
219
|
+
f.write(self.compiled_code["LOWERED"])
|
|
220
|
+
f.write(f"\n```\n")
|
|
221
|
+
f.write("\n---\n")
|
|
222
|
+
if self.save_jaxpr:
|
|
223
|
+
f.write(f"## JAXPR\n")
|
|
224
|
+
f.write(f"```haskel\n")
|
|
225
|
+
f.write(self.compiled_code["JAXPR"])
|
|
226
|
+
f.write(f"\n```\n")
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
+
jax_hpc_profiler/create_argparse.py,sha256=dEicamRYqJ6GGdgcph2bwAbmdxPkS4tS12xZ4c0X_Pk,6484
|
|
3
|
+
jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
|
|
4
|
+
jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
|
|
5
|
+
jax_hpc_profiler/timer.py,sha256=j6oH5IZz12VJik2cE7EQ3a9tAW9C8xl7D2QLW8Bkz3s,8617
|
|
6
|
+
jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
|
|
7
|
+
jax_hpc_profiler-0.2.6.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
+
jax_hpc_profiler-0.2.6.dist-info/METADATA,sha256=AgLXyb89gdxgeyDv02_P5oqIvoffdTh1mc3zFPUVBAU,49250
|
|
9
|
+
jax_hpc_profiler-0.2.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
10
|
+
jax_hpc_profiler-0.2.6.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
+
jax_hpc_profiler-0.2.6.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
+
jax_hpc_profiler-0.2.6.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
|
|
2
|
-
jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
|
|
3
|
-
jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
|
|
4
|
-
jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
|
|
5
|
-
jax_hpc_profiler/timer.py,sha256=baE5DRsQBYRBiphkceTi4qI_8FPGKQEh73f2pAeS-oc,8208
|
|
6
|
-
jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
|
|
7
|
-
jax_hpc_profiler-0.2.3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
-
jax_hpc_profiler-0.2.3.dist-info/METADATA,sha256=myC-zD7y_pRb_-tZoSFi0KmglZH8Gk88_-U5RE14Q04,49250
|
|
9
|
-
jax_hpc_profiler-0.2.3.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
10
|
-
jax_hpc_profiler-0.2.3.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
-
jax_hpc_profiler-0.2.3.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
-
jax_hpc_profiler-0.2.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|