jax-hpc-profiler 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/timer.py +94 -99
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.5.dist-info}/METADATA +1 -1
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.5.dist-info}/RECORD +7 -7
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.5.dist-info}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.5.dist-info}/WHEEL +0 -0
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.5.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.3.dist-info → jax_hpc_profiler-0.2.5.dist-info}/top_level.txt +0 -0
jax_hpc_profiler/timer.py
CHANGED
|
@@ -28,15 +28,14 @@ class Timer:
|
|
|
28
28
|
return None
|
|
29
29
|
return cost_analysis[0]['flops']
|
|
30
30
|
|
|
31
|
-
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
|
-
|
|
33
|
-
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB'
|
|
34
|
-
factors = [1
|
|
31
|
+
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
|
+
|
|
33
|
+
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
34
|
+
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
35
|
factor = int(np.log10(memory_analysis) // 3)
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
38
38
|
|
|
39
|
-
|
|
40
39
|
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
41
40
|
if memory_analysis is None:
|
|
42
41
|
return None, None, None, None
|
|
@@ -68,12 +67,9 @@ class Timer:
|
|
|
68
67
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
69
68
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
70
69
|
self.profiling_data["FLOPS"] = cost_analysis
|
|
71
|
-
self.profiling_data[
|
|
72
|
-
|
|
73
|
-
self.profiling_data[
|
|
74
|
-
"argument_size"] = memory_analysis[1]
|
|
75
|
-
self.profiling_data[
|
|
76
|
-
"output_size"] = memory_analysis[2]
|
|
70
|
+
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
71
|
+
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
72
|
+
self.profiling_data["output_size"] = memory_analysis[2]
|
|
77
73
|
self.profiling_data["temp_size"] = memory_analysis[3]
|
|
78
74
|
return out
|
|
79
75
|
|
|
@@ -101,7 +97,7 @@ class Timer:
|
|
|
101
97
|
global_times = jax.make_array_from_callback(
|
|
102
98
|
shape=global_shape,
|
|
103
99
|
sharding=sharding,
|
|
104
|
-
data_callback=lambda _: jnp.expand_dims(times_array,axis=0))
|
|
100
|
+
data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
|
|
105
101
|
|
|
106
102
|
@partial(shard_map,
|
|
107
103
|
mesh=mesh,
|
|
@@ -141,90 +137,89 @@ class Timer:
|
|
|
141
137
|
z = x if z is None else z
|
|
142
138
|
|
|
143
139
|
times_array = self._get_mean_times()
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
f.write(f"\n```\n")
|
|
219
|
-
f.write("\n---\n")
|
|
220
|
-
f.write(f"## Lowered Code\n")
|
|
221
|
-
f.write(f"```hlo\n")
|
|
222
|
-
f.write(self.compiled_code["LOWERED"])
|
|
223
|
-
f.write(f"\n```\n")
|
|
224
|
-
f.write("\n---\n")
|
|
225
|
-
if self.save_jaxpr:
|
|
226
|
-
f.write(f"## JAXPR\n")
|
|
227
|
-
f.write(f"```haskel\n")
|
|
228
|
-
f.write(self.compiled_code["JAXPR"])
|
|
140
|
+
if jax.process_index() == 0:
|
|
141
|
+
min_time = np.min(times_array)
|
|
142
|
+
max_time = np.max(times_array)
|
|
143
|
+
mean_time = np.mean(times_array)
|
|
144
|
+
std_time = np.std(times_array)
|
|
145
|
+
last_time = times_array[-1]
|
|
146
|
+
|
|
147
|
+
flops = self.profiling_data["FLOPS"]
|
|
148
|
+
generated_code = self.profiling_data["generated_code"]
|
|
149
|
+
argument_size = self.profiling_data["argument_size"]
|
|
150
|
+
output_size = self.profiling_data["output_size"]
|
|
151
|
+
temp_size = self.profiling_data["temp_size"]
|
|
152
|
+
|
|
153
|
+
csv_line = (
|
|
154
|
+
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
155
|
+
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
156
|
+
f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
with open(csv_filename, 'a') as f:
|
|
160
|
+
f.write(csv_line)
|
|
161
|
+
|
|
162
|
+
param_dict = {
|
|
163
|
+
"Function": function,
|
|
164
|
+
"Precision": precision,
|
|
165
|
+
"X": x,
|
|
166
|
+
"Y": y,
|
|
167
|
+
"Z": z,
|
|
168
|
+
"PX": px,
|
|
169
|
+
"PY": py,
|
|
170
|
+
"Backend": backend,
|
|
171
|
+
"Nodes": nodes,
|
|
172
|
+
}
|
|
173
|
+
param_dict.update(extra_info)
|
|
174
|
+
profiling_result = {
|
|
175
|
+
"JIT Time": self.jit_time,
|
|
176
|
+
"Min Time": min_time,
|
|
177
|
+
"Max Time": max_time,
|
|
178
|
+
"Mean Time": mean_time,
|
|
179
|
+
"Std Time": std_time,
|
|
180
|
+
"Last Time": last_time,
|
|
181
|
+
"Generated Code": self._normalize_memory_units(generated_code),
|
|
182
|
+
"Argument Size": self._normalize_memory_units(argument_size),
|
|
183
|
+
"Output Size": self._normalize_memory_units(output_size),
|
|
184
|
+
"Temporary Size": self._normalize_memory_units(temp_size),
|
|
185
|
+
"FLOPS": self.profiling_data["FLOPS"]
|
|
186
|
+
}
|
|
187
|
+
iteration_runs = {}
|
|
188
|
+
for i in range(len(times_array)):
|
|
189
|
+
iteration_runs[f"Run {i}"] = times_array[i]
|
|
190
|
+
|
|
191
|
+
with open(md_filename, 'w') as f:
|
|
192
|
+
f.write(f"# Reporting for {function}\n")
|
|
193
|
+
f.write(f"## Parameters\n")
|
|
194
|
+
f.write(
|
|
195
|
+
tabulate(param_dict.items(),
|
|
196
|
+
headers=["Parameter", "Value"],
|
|
197
|
+
tablefmt='github'))
|
|
198
|
+
f.write("\n---\n")
|
|
199
|
+
f.write(f"## Profiling Data\n")
|
|
200
|
+
f.write(
|
|
201
|
+
tabulate(profiling_result.items(),
|
|
202
|
+
headers=["Parameter", "Value"],
|
|
203
|
+
tablefmt='github'))
|
|
204
|
+
f.write("\n---\n")
|
|
205
|
+
f.write(f"## Iteration Runs\n")
|
|
206
|
+
f.write(
|
|
207
|
+
tabulate(iteration_runs.items(),
|
|
208
|
+
headers=["Iteration", "Time"],
|
|
209
|
+
tablefmt='github'))
|
|
210
|
+
f.write("\n---\n")
|
|
211
|
+
f.write(f"## Compiled Code\n")
|
|
212
|
+
f.write(f"```hlo\n")
|
|
213
|
+
f.write(self.compiled_code["COMPILED"])
|
|
229
214
|
f.write(f"\n```\n")
|
|
230
|
-
|
|
215
|
+
f.write("\n---\n")
|
|
216
|
+
f.write(f"## Lowered Code\n")
|
|
217
|
+
f.write(f"```hlo\n")
|
|
218
|
+
f.write(self.compiled_code["LOWERED"])
|
|
219
|
+
f.write(f"\n```\n")
|
|
220
|
+
f.write("\n---\n")
|
|
221
|
+
if self.save_jaxpr:
|
|
222
|
+
f.write(f"## JAXPR\n")
|
|
223
|
+
f.write(f"```haskel\n")
|
|
224
|
+
f.write(self.compiled_code["JAXPR"])
|
|
225
|
+
f.write(f"\n```\n")
|
|
@@ -2,11 +2,11 @@ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,
|
|
|
2
2
|
jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
|
|
3
3
|
jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
|
|
4
4
|
jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
|
|
5
|
-
jax_hpc_profiler/timer.py,sha256=
|
|
5
|
+
jax_hpc_profiler/timer.py,sha256=r4Mw2tC82cxvkMPkIy8BuZjKikgxn6cviEgmu6rpC9o,8616
|
|
6
6
|
jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
|
|
7
|
-
jax_hpc_profiler-0.2.
|
|
8
|
-
jax_hpc_profiler-0.2.
|
|
9
|
-
jax_hpc_profiler-0.2.
|
|
10
|
-
jax_hpc_profiler-0.2.
|
|
11
|
-
jax_hpc_profiler-0.2.
|
|
12
|
-
jax_hpc_profiler-0.2.
|
|
7
|
+
jax_hpc_profiler-0.2.5.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
+
jax_hpc_profiler-0.2.5.dist-info/METADATA,sha256=6Vk6fA1nz-m8ZZVzSZPsg9xR9iJkFJ01x36pddG0RAM,49250
|
|
9
|
+
jax_hpc_profiler-0.2.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
10
|
+
jax_hpc_profiler-0.2.5.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
+
jax_hpc_profiler-0.2.5.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
+
jax_hpc_profiler-0.2.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|