jax-hpc-profiler 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/__init__.py +7 -2
- jax_hpc_profiler/create_argparse.py +109 -120
- jax_hpc_profiler/main.py +15 -19
- jax_hpc_profiler/plotting.py +58 -66
- jax_hpc_profiler/timer.py +109 -122
- jax_hpc_profiler/utils.py +191 -132
- {jax_hpc_profiler-0.2.11.dist-info → jax_hpc_profiler-0.2.13.dist-info}/METADATA +1 -1
- jax_hpc_profiler-0.2.13.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.11.dist-info → jax_hpc_profiler-0.2.13.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.11.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.11.dist-info → jax_hpc_profiler-0.2.13.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.11.dist-info → jax_hpc_profiler-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {jax_hpc_profiler-0.2.11.dist-info → jax_hpc_profiler-0.2.13.dist-info}/top_level.txt +0 -0
jax_hpc_profiler/timer.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import time
|
|
3
3
|
from functools import partial
|
|
4
|
-
from typing import Any, Callable,
|
|
4
|
+
from typing import Any, Callable, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import numpy as np
|
|
9
9
|
from jax import make_jaxpr
|
|
10
|
-
from jax.experimental import mesh_utils
|
|
11
10
|
from jax.experimental.shard_map import shard_map
|
|
12
|
-
from jax.sharding import
|
|
11
|
+
from jax.sharding import NamedSharding
|
|
13
12
|
from jax.sharding import PartitionSpec as P
|
|
14
13
|
from jaxtyping import Array
|
|
15
14
|
from tabulate import tabulate
|
|
16
15
|
|
|
17
16
|
|
|
18
17
|
class Timer:
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
18
|
+
def __init__(
|
|
19
|
+
self,
|
|
20
|
+
save_jaxpr=False,
|
|
21
|
+
compile_info=True,
|
|
22
|
+
jax_fn=True,
|
|
23
|
+
devices=None,
|
|
24
|
+
static_argnums=(),
|
|
25
|
+
):
|
|
26
26
|
self.jit_time = 0.0
|
|
27
27
|
self.times = []
|
|
28
28
|
self.profiling_data = {}
|
|
@@ -34,16 +34,14 @@ class Timer:
|
|
|
34
34
|
self.static_argnums = static_argnums
|
|
35
35
|
|
|
36
36
|
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
37
|
-
|
|
38
37
|
if not (self.jax_fn and self.compile_info):
|
|
39
38
|
return memory_analysis
|
|
40
39
|
|
|
41
|
-
sizes_str = [
|
|
40
|
+
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
42
41
|
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
43
|
-
factor = 0 if memory_analysis == 0 else int(
|
|
44
|
-
np.log10(memory_analysis) // 3)
|
|
42
|
+
factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
|
|
45
43
|
|
|
46
|
-
return f
|
|
44
|
+
return f'{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}'
|
|
47
45
|
|
|
48
46
|
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
49
47
|
if memory_analysis is None:
|
|
@@ -64,37 +62,33 @@ class Timer:
|
|
|
64
62
|
if isinstance(x, Array):
|
|
65
63
|
x.block_until_ready()
|
|
66
64
|
|
|
67
|
-
jax.
|
|
65
|
+
jax.tree.map(_block, out)
|
|
68
66
|
end = time.perf_counter()
|
|
69
67
|
self.jit_time = (end - start) * 1e3
|
|
70
68
|
|
|
71
|
-
self.compiled_code[
|
|
72
|
-
self.compiled_code[
|
|
73
|
-
self.compiled_code[
|
|
74
|
-
self.profiling_data[
|
|
75
|
-
self.profiling_data[
|
|
76
|
-
self.profiling_data[
|
|
77
|
-
self.profiling_data[
|
|
69
|
+
self.compiled_code['JAXPR'] = 'N/A'
|
|
70
|
+
self.compiled_code['LOWERED'] = 'N/A'
|
|
71
|
+
self.compiled_code['COMPILED'] = 'N/A'
|
|
72
|
+
self.profiling_data['generated_code'] = 'N/A'
|
|
73
|
+
self.profiling_data['argument_size'] = 'N/A'
|
|
74
|
+
self.profiling_data['output_size'] = 'N/A'
|
|
75
|
+
self.profiling_data['temp_size'] = 'N/A'
|
|
78
76
|
|
|
79
77
|
if self.save_jaxpr:
|
|
80
|
-
jaxpr = make_jaxpr(fun,
|
|
81
|
-
|
|
82
|
-
**kwargs)
|
|
83
|
-
self.compiled_code["JAXPR"] = jaxpr.pretty_print()
|
|
78
|
+
jaxpr = make_jaxpr(fun, static_argnums=self.static_argnums)(*args, **kwargs)
|
|
79
|
+
self.compiled_code['JAXPR'] = jaxpr.pretty_print()
|
|
84
80
|
|
|
85
81
|
if self.jax_fn and self.compile_info:
|
|
86
|
-
lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
|
|
87
|
-
*args, **kwargs)
|
|
82
|
+
lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(*args, **kwargs)
|
|
88
83
|
compiled = lowered.compile()
|
|
89
|
-
memory_analysis = self._read_memory_analysis(
|
|
90
|
-
compiled.memory_analysis())
|
|
84
|
+
memory_analysis = self._read_memory_analysis(compiled.memory_analysis())
|
|
91
85
|
|
|
92
|
-
self.compiled_code[
|
|
93
|
-
self.compiled_code[
|
|
94
|
-
self.profiling_data[
|
|
95
|
-
self.profiling_data[
|
|
96
|
-
self.profiling_data[
|
|
97
|
-
self.profiling_data[
|
|
86
|
+
self.compiled_code['LOWERED'] = lowered.as_text()
|
|
87
|
+
self.compiled_code['COMPILED'] = compiled.as_text()
|
|
88
|
+
self.profiling_data['generated_code'] = memory_analysis[0]
|
|
89
|
+
self.profiling_data['argument_size'] = memory_analysis[1]
|
|
90
|
+
self.profiling_data['output_size'] = memory_analysis[2]
|
|
91
|
+
self.profiling_data['temp_size'] = memory_analysis[3]
|
|
98
92
|
|
|
99
93
|
return out
|
|
100
94
|
|
|
@@ -107,7 +101,7 @@ class Timer:
|
|
|
107
101
|
if isinstance(x, Array):
|
|
108
102
|
x.block_until_ready()
|
|
109
103
|
|
|
110
|
-
jax.
|
|
104
|
+
jax.tree.map(_block, out)
|
|
111
105
|
end = time.perf_counter()
|
|
112
106
|
self.times.append((end - start) * 1e3)
|
|
113
107
|
return out
|
|
@@ -119,9 +113,8 @@ class Timer:
|
|
|
119
113
|
if self.devices is None:
|
|
120
114
|
self.devices = jax.devices()
|
|
121
115
|
|
|
122
|
-
mesh = jax.make_mesh((len(self.devices),
|
|
123
|
-
|
|
124
|
-
sharding = NamedSharding(mesh, P("x"))
|
|
116
|
+
mesh = jax.make_mesh((len(self.devices),), ('x',), devices=self.devices)
|
|
117
|
+
sharding = NamedSharding(mesh, P('x'))
|
|
125
118
|
|
|
126
119
|
times_array = jnp.array(self.times)
|
|
127
120
|
global_shape = (jax.device_count(), times_array.shape[0])
|
|
@@ -131,13 +124,9 @@ class Timer:
|
|
|
131
124
|
data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
|
|
132
125
|
)
|
|
133
126
|
|
|
134
|
-
@partial(shard_map,
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
out_specs=P(),
|
|
138
|
-
check_rep=False)
|
|
139
|
-
def get_mean_times(times):
|
|
140
|
-
return jax.lax.pmean(times, axis_name="x")
|
|
127
|
+
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P(), check_rep=False)
|
|
128
|
+
def get_mean_times(times) -> Array:
|
|
129
|
+
return jax.lax.pmean(times, axis_name='x')
|
|
141
130
|
|
|
142
131
|
times_array = get_mean_times(global_times)
|
|
143
132
|
times_array.block_until_ready()
|
|
@@ -150,17 +139,17 @@ class Timer:
|
|
|
150
139
|
x: int,
|
|
151
140
|
y: int | None = None,
|
|
152
141
|
z: int | None = None,
|
|
153
|
-
precision: str =
|
|
142
|
+
precision: str = 'float32',
|
|
154
143
|
px: int = 1,
|
|
155
144
|
py: int = 1,
|
|
156
|
-
backend: str =
|
|
145
|
+
backend: str = 'NCCL',
|
|
157
146
|
nodes: int = 1,
|
|
158
147
|
md_filename: str | None = None,
|
|
159
|
-
npz_data: Optional[dict] = None,
|
|
160
|
-
extra_info: dict = {},
|
|
161
|
-
):
|
|
148
|
+
npz_data: Optional[dict[str, Any]] = None,
|
|
149
|
+
extra_info: dict[str, Any] = {},
|
|
150
|
+
) -> None:
|
|
162
151
|
if self.jit_time == 0.0 and len(self.times) == 0:
|
|
163
|
-
print(f
|
|
152
|
+
print(f'No profiling data to report for {function}')
|
|
164
153
|
return
|
|
165
154
|
|
|
166
155
|
if md_filename is None:
|
|
@@ -168,22 +157,18 @@ class Timer:
|
|
|
168
157
|
os.path.dirname(csv_filename),
|
|
169
158
|
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
170
159
|
)
|
|
171
|
-
report_folder = filename if dirname ==
|
|
160
|
+
report_folder = filename if dirname == '' else f'{dirname}/{filename}'
|
|
172
161
|
os.makedirs(report_folder, exist_ok=True)
|
|
173
|
-
md_filename =
|
|
174
|
-
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
|
|
175
|
-
)
|
|
162
|
+
md_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md'
|
|
176
163
|
|
|
177
164
|
if npz_data is not None:
|
|
178
165
|
dirname, filename = (
|
|
179
166
|
os.path.dirname(csv_filename),
|
|
180
167
|
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
181
168
|
)
|
|
182
|
-
report_folder = filename if dirname ==
|
|
169
|
+
report_folder = filename if dirname == '' else f'{dirname}/{filename}'
|
|
183
170
|
os.makedirs(report_folder, exist_ok=True)
|
|
184
|
-
npz_filename =
|
|
185
|
-
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
|
|
186
|
-
)
|
|
171
|
+
npz_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz'
|
|
187
172
|
np.savez(npz_filename, **npz_data)
|
|
188
173
|
|
|
189
174
|
y = x if y is None else y
|
|
@@ -191,96 +176,98 @@ class Timer:
|
|
|
191
176
|
|
|
192
177
|
times_array = self._get_mean_times()
|
|
193
178
|
if jax.process_index() == 0:
|
|
194
|
-
|
|
195
179
|
min_time = np.min(times_array)
|
|
196
180
|
max_time = np.max(times_array)
|
|
197
181
|
mean_time = np.mean(times_array)
|
|
198
182
|
std_time = np.std(times_array)
|
|
199
183
|
last_time = times_array[-1]
|
|
200
|
-
generated_code = self.profiling_data[
|
|
201
|
-
argument_size = self.profiling_data[
|
|
202
|
-
output_size = self.profiling_data[
|
|
203
|
-
temp_size = self.profiling_data[
|
|
184
|
+
generated_code = self.profiling_data['generated_code']
|
|
185
|
+
argument_size = self.profiling_data['argument_size']
|
|
186
|
+
output_size = self.profiling_data['output_size']
|
|
187
|
+
temp_size = self.profiling_data['temp_size']
|
|
204
188
|
|
|
205
189
|
csv_line = (
|
|
206
|
-
f
|
|
207
|
-
f
|
|
208
|
-
f
|
|
190
|
+
f'{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},'
|
|
191
|
+
f'{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},'
|
|
192
|
+
f'{generated_code},{argument_size},{output_size},{temp_size}\n'
|
|
209
193
|
)
|
|
210
194
|
|
|
211
|
-
with open(csv_filename,
|
|
195
|
+
with open(csv_filename, 'a') as f:
|
|
212
196
|
f.write(csv_line)
|
|
213
197
|
|
|
214
198
|
param_dict = {
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
199
|
+
'Function': function,
|
|
200
|
+
'Precision': precision,
|
|
201
|
+
'X': x,
|
|
202
|
+
'Y': y,
|
|
203
|
+
'Z': z,
|
|
204
|
+
'PX': px,
|
|
205
|
+
'PY': py,
|
|
206
|
+
'Backend': backend,
|
|
207
|
+
'Nodes': nodes,
|
|
224
208
|
}
|
|
225
209
|
param_dict.update(extra_info)
|
|
226
210
|
profiling_result = {
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
211
|
+
'JIT Time': self.jit_time,
|
|
212
|
+
'Min Time': min_time,
|
|
213
|
+
'Max Time': max_time,
|
|
214
|
+
'Mean Time': mean_time,
|
|
215
|
+
'Std Time': std_time,
|
|
216
|
+
'Last Time': last_time,
|
|
217
|
+
'Generated Code': self._normalize_memory_units(generated_code),
|
|
218
|
+
'Argument Size': self._normalize_memory_units(argument_size),
|
|
219
|
+
'Output Size': self._normalize_memory_units(output_size),
|
|
220
|
+
'Temporary Size': self._normalize_memory_units(temp_size),
|
|
237
221
|
}
|
|
238
222
|
iteration_runs = {}
|
|
239
223
|
for i in range(len(times_array)):
|
|
240
|
-
iteration_runs[f
|
|
224
|
+
iteration_runs[f'Run {i}'] = times_array[i]
|
|
241
225
|
|
|
242
|
-
with open(md_filename,
|
|
243
|
-
f.write(f
|
|
244
|
-
f.write(
|
|
226
|
+
with open(md_filename, 'w') as f:
|
|
227
|
+
f.write(f'# Reporting for {function}\n')
|
|
228
|
+
f.write('## Parameters\n')
|
|
245
229
|
f.write(
|
|
246
230
|
tabulate(
|
|
247
231
|
param_dict.items(),
|
|
248
|
-
headers=[
|
|
249
|
-
tablefmt=
|
|
250
|
-
)
|
|
251
|
-
|
|
252
|
-
f.write(
|
|
232
|
+
headers=['Parameter', 'Value'],
|
|
233
|
+
tablefmt='github',
|
|
234
|
+
)
|
|
235
|
+
)
|
|
236
|
+
f.write('\n---\n')
|
|
237
|
+
f.write('## Profiling Data\n')
|
|
253
238
|
f.write(
|
|
254
239
|
tabulate(
|
|
255
240
|
profiling_result.items(),
|
|
256
|
-
headers=[
|
|
257
|
-
tablefmt=
|
|
258
|
-
)
|
|
259
|
-
|
|
260
|
-
f.write(
|
|
241
|
+
headers=['Parameter', 'Value'],
|
|
242
|
+
tablefmt='github',
|
|
243
|
+
)
|
|
244
|
+
)
|
|
245
|
+
f.write('\n---\n')
|
|
246
|
+
f.write('## Iteration Runs\n')
|
|
261
247
|
f.write(
|
|
262
248
|
tabulate(
|
|
263
249
|
iteration_runs.items(),
|
|
264
|
-
headers=[
|
|
265
|
-
tablefmt=
|
|
266
|
-
)
|
|
250
|
+
headers=['Iteration', 'Time'],
|
|
251
|
+
tablefmt='github',
|
|
252
|
+
)
|
|
253
|
+
)
|
|
267
254
|
if self.jax_fn and self.compile_info:
|
|
268
|
-
f.write(
|
|
269
|
-
f.write(
|
|
270
|
-
f.write(
|
|
271
|
-
f.write(self.compiled_code[
|
|
272
|
-
f.write(
|
|
273
|
-
f.write(
|
|
274
|
-
f.write(
|
|
275
|
-
f.write(
|
|
276
|
-
f.write(self.compiled_code[
|
|
277
|
-
f.write(
|
|
278
|
-
f.write(
|
|
255
|
+
f.write('\n---\n')
|
|
256
|
+
f.write('## Compiled Code\n')
|
|
257
|
+
f.write('```hlo\n')
|
|
258
|
+
f.write(self.compiled_code['COMPILED'])
|
|
259
|
+
f.write('\n```\n')
|
|
260
|
+
f.write('\n---\n')
|
|
261
|
+
f.write('## Lowered Code\n')
|
|
262
|
+
f.write('```hlo\n')
|
|
263
|
+
f.write(self.compiled_code['LOWERED'])
|
|
264
|
+
f.write('\n```\n')
|
|
265
|
+
f.write('\n---\n')
|
|
279
266
|
if self.save_jaxpr:
|
|
280
|
-
f.write(
|
|
281
|
-
f.write(
|
|
282
|
-
f.write(self.compiled_code[
|
|
283
|
-
f.write(
|
|
267
|
+
f.write('## JAXPR\n')
|
|
268
|
+
f.write('```haskel\n')
|
|
269
|
+
f.write(self.compiled_code['JAXPR'])
|
|
270
|
+
f.write('\n```\n')
|
|
284
271
|
|
|
285
272
|
# Reset the timer
|
|
286
273
|
self.jit_time = 0.0
|