jax-hpc-profiler 0.2.2__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/timer.py +100 -95
- {jax_hpc_profiler-0.2.2.dist-info → jax_hpc_profiler-0.2.5.dist-info}/METADATA +1 -1
- {jax_hpc_profiler-0.2.2.dist-info → jax_hpc_profiler-0.2.5.dist-info}/RECORD +7 -7
- {jax_hpc_profiler-0.2.2.dist-info → jax_hpc_profiler-0.2.5.dist-info}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.2.dist-info → jax_hpc_profiler-0.2.5.dist-info}/WHEEL +0 -0
- {jax_hpc_profiler-0.2.2.dist-info → jax_hpc_profiler-0.2.5.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.2.dist-info → jax_hpc_profiler-0.2.5.dist-info}/top_level.txt +0 -0
jax_hpc_profiler/timer.py
CHANGED
|
@@ -28,6 +28,14 @@ class Timer:
|
|
|
28
28
|
return None
|
|
29
29
|
return cost_analysis[0]['flops']
|
|
30
30
|
|
|
31
|
+
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
|
+
|
|
33
|
+
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
34
|
+
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
|
+
factor = int(np.log10(memory_analysis) // 3)
|
|
36
|
+
|
|
37
|
+
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
38
|
+
|
|
31
39
|
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
32
40
|
if memory_analysis is None:
|
|
33
41
|
return None, None, None, None
|
|
@@ -38,7 +46,7 @@ class Timer:
|
|
|
38
46
|
|
|
39
47
|
def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
40
48
|
start = time.perf_counter()
|
|
41
|
-
out =
|
|
49
|
+
out = fun(*args)
|
|
42
50
|
if ndarray_arg is None:
|
|
43
51
|
out.block_until_ready()
|
|
44
52
|
else:
|
|
@@ -59,13 +67,10 @@ class Timer:
|
|
|
59
67
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
60
68
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
61
69
|
self.profiling_data["FLOPS"] = cost_analysis
|
|
62
|
-
self.profiling_data[
|
|
63
|
-
|
|
64
|
-
self.profiling_data[
|
|
65
|
-
|
|
66
|
-
self.profiling_data[
|
|
67
|
-
"output_size"] = memory_analysis[0]
|
|
68
|
-
self.profiling_data["temp_size"] = memory_analysis[0]
|
|
70
|
+
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
71
|
+
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
72
|
+
self.profiling_data["output_size"] = memory_analysis[2]
|
|
73
|
+
self.profiling_data["temp_size"] = memory_analysis[3]
|
|
69
74
|
return out
|
|
70
75
|
|
|
71
76
|
def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
|
|
@@ -92,7 +97,7 @@ class Timer:
|
|
|
92
97
|
global_times = jax.make_array_from_callback(
|
|
93
98
|
shape=global_shape,
|
|
94
99
|
sharding=sharding,
|
|
95
|
-
data_callback=lambda
|
|
100
|
+
data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
|
|
96
101
|
|
|
97
102
|
@partial(shard_map,
|
|
98
103
|
mesh=mesh,
|
|
@@ -104,7 +109,7 @@ class Timer:
|
|
|
104
109
|
|
|
105
110
|
times_array = get_mean_times(global_times)
|
|
106
111
|
times_array.block_until_ready()
|
|
107
|
-
return np.array(times_array.addressable_data(0))
|
|
112
|
+
return np.array(times_array.addressable_data(0)[0])
|
|
108
113
|
|
|
109
114
|
def report(self,
|
|
110
115
|
csv_filename: str,
|
|
@@ -132,89 +137,89 @@ class Timer:
|
|
|
132
137
|
z = x if z is None else z
|
|
133
138
|
|
|
134
139
|
times_array = self._get_mean_times()
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
f.write(f"\n```\n")
|
|
215
|
-
f.write("\n---\n")
|
|
216
|
-
if self.save_jaxpr:
|
|
217
|
-
f.write(f"## JAXPR\n")
|
|
218
|
-
f.write(f"```haskel\n")
|
|
219
|
-
f.write(self.compiled_code["JAXPR"])
|
|
140
|
+
if jax.process_index() == 0:
|
|
141
|
+
min_time = np.min(times_array)
|
|
142
|
+
max_time = np.max(times_array)
|
|
143
|
+
mean_time = np.mean(times_array)
|
|
144
|
+
std_time = np.std(times_array)
|
|
145
|
+
last_time = times_array[-1]
|
|
146
|
+
|
|
147
|
+
flops = self.profiling_data["FLOPS"]
|
|
148
|
+
generated_code = self.profiling_data["generated_code"]
|
|
149
|
+
argument_size = self.profiling_data["argument_size"]
|
|
150
|
+
output_size = self.profiling_data["output_size"]
|
|
151
|
+
temp_size = self.profiling_data["temp_size"]
|
|
152
|
+
|
|
153
|
+
csv_line = (
|
|
154
|
+
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
155
|
+
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
156
|
+
f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
with open(csv_filename, 'a') as f:
|
|
160
|
+
f.write(csv_line)
|
|
161
|
+
|
|
162
|
+
param_dict = {
|
|
163
|
+
"Function": function,
|
|
164
|
+
"Precision": precision,
|
|
165
|
+
"X": x,
|
|
166
|
+
"Y": y,
|
|
167
|
+
"Z": z,
|
|
168
|
+
"PX": px,
|
|
169
|
+
"PY": py,
|
|
170
|
+
"Backend": backend,
|
|
171
|
+
"Nodes": nodes,
|
|
172
|
+
}
|
|
173
|
+
param_dict.update(extra_info)
|
|
174
|
+
profiling_result = {
|
|
175
|
+
"JIT Time": self.jit_time,
|
|
176
|
+
"Min Time": min_time,
|
|
177
|
+
"Max Time": max_time,
|
|
178
|
+
"Mean Time": mean_time,
|
|
179
|
+
"Std Time": std_time,
|
|
180
|
+
"Last Time": last_time,
|
|
181
|
+
"Generated Code": self._normalize_memory_units(generated_code),
|
|
182
|
+
"Argument Size": self._normalize_memory_units(argument_size),
|
|
183
|
+
"Output Size": self._normalize_memory_units(output_size),
|
|
184
|
+
"Temporary Size": self._normalize_memory_units(temp_size),
|
|
185
|
+
"FLOPS": self.profiling_data["FLOPS"]
|
|
186
|
+
}
|
|
187
|
+
iteration_runs = {}
|
|
188
|
+
for i in range(len(times_array)):
|
|
189
|
+
iteration_runs[f"Run {i}"] = times_array[i]
|
|
190
|
+
|
|
191
|
+
with open(md_filename, 'w') as f:
|
|
192
|
+
f.write(f"# Reporting for {function}\n")
|
|
193
|
+
f.write(f"## Parameters\n")
|
|
194
|
+
f.write(
|
|
195
|
+
tabulate(param_dict.items(),
|
|
196
|
+
headers=["Parameter", "Value"],
|
|
197
|
+
tablefmt='github'))
|
|
198
|
+
f.write("\n---\n")
|
|
199
|
+
f.write(f"## Profiling Data\n")
|
|
200
|
+
f.write(
|
|
201
|
+
tabulate(profiling_result.items(),
|
|
202
|
+
headers=["Parameter", "Value"],
|
|
203
|
+
tablefmt='github'))
|
|
204
|
+
f.write("\n---\n")
|
|
205
|
+
f.write(f"## Iteration Runs\n")
|
|
206
|
+
f.write(
|
|
207
|
+
tabulate(iteration_runs.items(),
|
|
208
|
+
headers=["Iteration", "Time"],
|
|
209
|
+
tablefmt='github'))
|
|
210
|
+
f.write("\n---\n")
|
|
211
|
+
f.write(f"## Compiled Code\n")
|
|
212
|
+
f.write(f"```hlo\n")
|
|
213
|
+
f.write(self.compiled_code["COMPILED"])
|
|
214
|
+
f.write(f"\n```\n")
|
|
215
|
+
f.write("\n---\n")
|
|
216
|
+
f.write(f"## Lowered Code\n")
|
|
217
|
+
f.write(f"```hlo\n")
|
|
218
|
+
f.write(self.compiled_code["LOWERED"])
|
|
220
219
|
f.write(f"\n```\n")
|
|
220
|
+
f.write("\n---\n")
|
|
221
|
+
if self.save_jaxpr:
|
|
222
|
+
f.write(f"## JAXPR\n")
|
|
223
|
+
f.write(f"```haskel\n")
|
|
224
|
+
f.write(self.compiled_code["JAXPR"])
|
|
225
|
+
f.write(f"\n```\n")
|
|
@@ -2,11 +2,11 @@ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,
|
|
|
2
2
|
jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
|
|
3
3
|
jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
|
|
4
4
|
jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
|
|
5
|
-
jax_hpc_profiler/timer.py,sha256=
|
|
5
|
+
jax_hpc_profiler/timer.py,sha256=r4Mw2tC82cxvkMPkIy8BuZjKikgxn6cviEgmu6rpC9o,8616
|
|
6
6
|
jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
|
|
7
|
-
jax_hpc_profiler-0.2.
|
|
8
|
-
jax_hpc_profiler-0.2.
|
|
9
|
-
jax_hpc_profiler-0.2.
|
|
10
|
-
jax_hpc_profiler-0.2.
|
|
11
|
-
jax_hpc_profiler-0.2.
|
|
12
|
-
jax_hpc_profiler-0.2.
|
|
7
|
+
jax_hpc_profiler-0.2.5.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
|
8
|
+
jax_hpc_profiler-0.2.5.dist-info/METADATA,sha256=6Vk6fA1nz-m8ZZVzSZPsg9xR9iJkFJ01x36pddG0RAM,49250
|
|
9
|
+
jax_hpc_profiler-0.2.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
10
|
+
jax_hpc_profiler-0.2.5.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
|
|
11
|
+
jax_hpc_profiler-0.2.5.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
|
|
12
|
+
jax_hpc_profiler-0.2.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|