jax-hpc-profiler 0.2.12__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jax_hpc_profiler/__init__.py +9 -3
- jax_hpc_profiler/create_argparse.py +128 -120
- jax_hpc_profiler/main.py +41 -22
- jax_hpc_profiler/plotting.py +250 -68
- jax_hpc_profiler/timer.py +117 -126
- jax_hpc_profiler/utils.py +191 -132
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/METADATA +36 -4
- jax_hpc_profiler-0.3.0.dist-info/RECORD +12 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/WHEEL +1 -1
- jax_hpc_profiler-0.2.12.dist-info/RECORD +0 -12
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {jax_hpc_profiler-0.2.12.dist-info → jax_hpc_profiler-0.3.0.dist-info}/top_level.txt +0 -0
jax_hpc_profiler/timer.py
CHANGED
|
@@ -1,32 +1,36 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import time
|
|
3
3
|
from functools import partial
|
|
4
|
-
from typing import Any, Callable,
|
|
4
|
+
from typing import Any, Callable, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import numpy as np
|
|
9
|
-
from jax import make_jaxpr
|
|
10
|
-
from jax.
|
|
11
|
-
from jax.experimental.shard_map import shard_map
|
|
12
|
-
from jax.sharding import Mesh, NamedSharding
|
|
9
|
+
from jax import make_jaxpr, shard_map
|
|
10
|
+
from jax.sharding import NamedSharding
|
|
13
11
|
from jax.sharding import PartitionSpec as P
|
|
14
12
|
from jaxtyping import Array
|
|
15
13
|
from tabulate import tabulate
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
class Timer:
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
17
|
+
def __init__(
|
|
18
|
+
self,
|
|
19
|
+
save_jaxpr=False,
|
|
20
|
+
compile_info=True,
|
|
21
|
+
jax_fn=True,
|
|
22
|
+
devices=None,
|
|
23
|
+
static_argnums=(),
|
|
24
|
+
):
|
|
26
25
|
self.jit_time = 0.0
|
|
27
26
|
self.times = []
|
|
28
|
-
self.profiling_data = {
|
|
29
|
-
|
|
27
|
+
self.profiling_data = {
|
|
28
|
+
'generated_code': 'N/A',
|
|
29
|
+
'argument_size': 'N/A',
|
|
30
|
+
'output_size': 'N/A',
|
|
31
|
+
'temp_size': 'N/A',
|
|
32
|
+
}
|
|
33
|
+
self.compiled_code = {'JAXPR': 'N/A', 'LOWERED': 'N/A', 'COMPILED': 'N/A'}
|
|
30
34
|
self.save_jaxpr = save_jaxpr
|
|
31
35
|
self.compile_info = compile_info
|
|
32
36
|
self.jax_fn = jax_fn
|
|
@@ -34,16 +38,14 @@ class Timer:
|
|
|
34
38
|
self.static_argnums = static_argnums
|
|
35
39
|
|
|
36
40
|
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
37
|
-
|
|
38
41
|
if not (self.jax_fn and self.compile_info):
|
|
39
42
|
return memory_analysis
|
|
40
43
|
|
|
41
|
-
sizes_str = [
|
|
44
|
+
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
42
45
|
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
43
|
-
factor = 0 if memory_analysis == 0 else int(
|
|
44
|
-
np.log10(memory_analysis) // 3)
|
|
46
|
+
factor = 0 if memory_analysis == 0 else int(np.log10(memory_analysis) // 3)
|
|
45
47
|
|
|
46
|
-
return f
|
|
48
|
+
return f'{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}'
|
|
47
49
|
|
|
48
50
|
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
49
51
|
if memory_analysis is None:
|
|
@@ -64,37 +66,33 @@ class Timer:
|
|
|
64
66
|
if isinstance(x, Array):
|
|
65
67
|
x.block_until_ready()
|
|
66
68
|
|
|
67
|
-
jax.
|
|
69
|
+
jax.tree.map(_block, out)
|
|
68
70
|
end = time.perf_counter()
|
|
69
71
|
self.jit_time = (end - start) * 1e3
|
|
70
72
|
|
|
71
|
-
self.compiled_code[
|
|
72
|
-
self.compiled_code[
|
|
73
|
-
self.compiled_code[
|
|
74
|
-
self.profiling_data[
|
|
75
|
-
self.profiling_data[
|
|
76
|
-
self.profiling_data[
|
|
77
|
-
self.profiling_data[
|
|
73
|
+
self.compiled_code['JAXPR'] = 'N/A'
|
|
74
|
+
self.compiled_code['LOWERED'] = 'N/A'
|
|
75
|
+
self.compiled_code['COMPILED'] = 'N/A'
|
|
76
|
+
self.profiling_data['generated_code'] = 'N/A'
|
|
77
|
+
self.profiling_data['argument_size'] = 'N/A'
|
|
78
|
+
self.profiling_data['output_size'] = 'N/A'
|
|
79
|
+
self.profiling_data['temp_size'] = 'N/A'
|
|
78
80
|
|
|
79
81
|
if self.save_jaxpr:
|
|
80
|
-
jaxpr = make_jaxpr(fun,
|
|
81
|
-
|
|
82
|
-
**kwargs)
|
|
83
|
-
self.compiled_code["JAXPR"] = jaxpr.pretty_print()
|
|
82
|
+
jaxpr = make_jaxpr(fun, static_argnums=self.static_argnums)(*args, **kwargs)
|
|
83
|
+
self.compiled_code['JAXPR'] = jaxpr.pretty_print()
|
|
84
84
|
|
|
85
85
|
if self.jax_fn and self.compile_info:
|
|
86
|
-
lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(
|
|
87
|
-
*args, **kwargs)
|
|
86
|
+
lowered = jax.jit(fun, static_argnums=self.static_argnums).lower(*args, **kwargs)
|
|
88
87
|
compiled = lowered.compile()
|
|
89
|
-
memory_analysis = self._read_memory_analysis(
|
|
90
|
-
compiled.memory_analysis())
|
|
88
|
+
memory_analysis = self._read_memory_analysis(compiled.memory_analysis())
|
|
91
89
|
|
|
92
|
-
self.compiled_code[
|
|
93
|
-
self.compiled_code[
|
|
94
|
-
self.profiling_data[
|
|
95
|
-
self.profiling_data[
|
|
96
|
-
self.profiling_data[
|
|
97
|
-
self.profiling_data[
|
|
90
|
+
self.compiled_code['LOWERED'] = lowered.as_text()
|
|
91
|
+
self.compiled_code['COMPILED'] = compiled.as_text()
|
|
92
|
+
self.profiling_data['generated_code'] = memory_analysis[0]
|
|
93
|
+
self.profiling_data['argument_size'] = memory_analysis[1]
|
|
94
|
+
self.profiling_data['output_size'] = memory_analysis[2]
|
|
95
|
+
self.profiling_data['temp_size'] = memory_analysis[3]
|
|
98
96
|
|
|
99
97
|
return out
|
|
100
98
|
|
|
@@ -107,7 +105,7 @@ class Timer:
|
|
|
107
105
|
if isinstance(x, Array):
|
|
108
106
|
x.block_until_ready()
|
|
109
107
|
|
|
110
|
-
jax.
|
|
108
|
+
jax.tree.map(_block, out)
|
|
111
109
|
end = time.perf_counter()
|
|
112
110
|
self.times.append((end - start) * 1e3)
|
|
113
111
|
return out
|
|
@@ -119,9 +117,8 @@ class Timer:
|
|
|
119
117
|
if self.devices is None:
|
|
120
118
|
self.devices = jax.devices()
|
|
121
119
|
|
|
122
|
-
mesh = jax.make_mesh((len(self.devices),
|
|
123
|
-
|
|
124
|
-
sharding = NamedSharding(mesh, P("x"))
|
|
120
|
+
mesh = jax.make_mesh((len(self.devices),), ('x',), devices=self.devices)
|
|
121
|
+
sharding = NamedSharding(mesh, P('x'))
|
|
125
122
|
|
|
126
123
|
times_array = jnp.array(self.times)
|
|
127
124
|
global_shape = (jax.device_count(), times_array.shape[0])
|
|
@@ -131,13 +128,9 @@ class Timer:
|
|
|
131
128
|
data_callback=lambda _: jnp.expand_dims(times_array, axis=0),
|
|
132
129
|
)
|
|
133
130
|
|
|
134
|
-
@partial(shard_map,
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
out_specs=P(),
|
|
138
|
-
check_rep=False)
|
|
139
|
-
def get_mean_times(times):
|
|
140
|
-
return jax.lax.pmean(times, axis_name="x")
|
|
131
|
+
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P(), check_rep=False)
|
|
132
|
+
def get_mean_times(times) -> Array:
|
|
133
|
+
return jax.lax.pmean(times, axis_name='x')
|
|
141
134
|
|
|
142
135
|
times_array = get_mean_times(global_times)
|
|
143
136
|
times_array.block_until_ready()
|
|
@@ -150,17 +143,17 @@ class Timer:
|
|
|
150
143
|
x: int,
|
|
151
144
|
y: int | None = None,
|
|
152
145
|
z: int | None = None,
|
|
153
|
-
precision: str =
|
|
146
|
+
precision: str = 'float32',
|
|
154
147
|
px: int = 1,
|
|
155
148
|
py: int = 1,
|
|
156
|
-
backend: str =
|
|
149
|
+
backend: str = 'NCCL',
|
|
157
150
|
nodes: int = 1,
|
|
158
151
|
md_filename: str | None = None,
|
|
159
|
-
npz_data: Optional[dict] = None,
|
|
160
|
-
extra_info: dict = {},
|
|
161
|
-
):
|
|
152
|
+
npz_data: Optional[dict[str, Any]] = None,
|
|
153
|
+
extra_info: dict[str, Any] = {},
|
|
154
|
+
) -> None:
|
|
162
155
|
if self.jit_time == 0.0 and len(self.times) == 0:
|
|
163
|
-
print(f
|
|
156
|
+
print(f'No profiling data to report for {function}')
|
|
164
157
|
return
|
|
165
158
|
|
|
166
159
|
if md_filename is None:
|
|
@@ -168,22 +161,18 @@ class Timer:
|
|
|
168
161
|
os.path.dirname(csv_filename),
|
|
169
162
|
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
170
163
|
)
|
|
171
|
-
report_folder = filename if dirname ==
|
|
164
|
+
report_folder = filename if dirname == '' else f'{dirname}/{filename}'
|
|
172
165
|
os.makedirs(report_folder, exist_ok=True)
|
|
173
|
-
md_filename =
|
|
174
|
-
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
|
|
175
|
-
)
|
|
166
|
+
md_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md'
|
|
176
167
|
|
|
177
168
|
if npz_data is not None:
|
|
178
169
|
dirname, filename = (
|
|
179
170
|
os.path.dirname(csv_filename),
|
|
180
171
|
os.path.splitext(os.path.basename(csv_filename))[0],
|
|
181
172
|
)
|
|
182
|
-
report_folder = filename if dirname ==
|
|
173
|
+
report_folder = filename if dirname == '' else f'{dirname}/{filename}'
|
|
183
174
|
os.makedirs(report_folder, exist_ok=True)
|
|
184
|
-
npz_filename =
|
|
185
|
-
f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz"
|
|
186
|
-
)
|
|
175
|
+
npz_filename = f'{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.npz'
|
|
187
176
|
np.savez(npz_filename, **npz_data)
|
|
188
177
|
|
|
189
178
|
y = x if y is None else y
|
|
@@ -191,96 +180,98 @@ class Timer:
|
|
|
191
180
|
|
|
192
181
|
times_array = self._get_mean_times()
|
|
193
182
|
if jax.process_index() == 0:
|
|
194
|
-
|
|
195
183
|
min_time = np.min(times_array)
|
|
196
184
|
max_time = np.max(times_array)
|
|
197
185
|
mean_time = np.mean(times_array)
|
|
198
186
|
std_time = np.std(times_array)
|
|
199
187
|
last_time = times_array[-1]
|
|
200
|
-
generated_code = self.profiling_data
|
|
201
|
-
argument_size = self.profiling_data
|
|
202
|
-
output_size = self.profiling_data
|
|
203
|
-
temp_size = self.profiling_data
|
|
188
|
+
generated_code = self.profiling_data.get('generated_code', 'N/A')
|
|
189
|
+
argument_size = self.profiling_data.get('argument_size', 'N/A')
|
|
190
|
+
output_size = self.profiling_data.get('output_size', 'N/A')
|
|
191
|
+
temp_size = self.profiling_data.get('temp_size', 'N/A')
|
|
204
192
|
|
|
205
193
|
csv_line = (
|
|
206
|
-
f
|
|
207
|
-
f
|
|
208
|
-
f
|
|
194
|
+
f'{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},'
|
|
195
|
+
f'{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},'
|
|
196
|
+
f'{generated_code},{argument_size},{output_size},{temp_size}\n'
|
|
209
197
|
)
|
|
210
198
|
|
|
211
|
-
with open(csv_filename,
|
|
199
|
+
with open(csv_filename, 'a') as f:
|
|
212
200
|
f.write(csv_line)
|
|
213
201
|
|
|
214
202
|
param_dict = {
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
203
|
+
'Function': function,
|
|
204
|
+
'Precision': precision,
|
|
205
|
+
'X': x,
|
|
206
|
+
'Y': y,
|
|
207
|
+
'Z': z,
|
|
208
|
+
'PX': px,
|
|
209
|
+
'PY': py,
|
|
210
|
+
'Backend': backend,
|
|
211
|
+
'Nodes': nodes,
|
|
224
212
|
}
|
|
225
213
|
param_dict.update(extra_info)
|
|
226
214
|
profiling_result = {
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
215
|
+
'JIT Time': self.jit_time,
|
|
216
|
+
'Min Time': min_time,
|
|
217
|
+
'Max Time': max_time,
|
|
218
|
+
'Mean Time': mean_time,
|
|
219
|
+
'Std Time': std_time,
|
|
220
|
+
'Last Time': last_time,
|
|
221
|
+
'Generated Code': self._normalize_memory_units(generated_code),
|
|
222
|
+
'Argument Size': self._normalize_memory_units(argument_size),
|
|
223
|
+
'Output Size': self._normalize_memory_units(output_size),
|
|
224
|
+
'Temporary Size': self._normalize_memory_units(temp_size),
|
|
237
225
|
}
|
|
238
226
|
iteration_runs = {}
|
|
239
227
|
for i in range(len(times_array)):
|
|
240
|
-
iteration_runs[f
|
|
228
|
+
iteration_runs[f'Run {i}'] = times_array[i]
|
|
241
229
|
|
|
242
|
-
with open(md_filename,
|
|
243
|
-
f.write(f
|
|
244
|
-
f.write(
|
|
230
|
+
with open(md_filename, 'w') as f:
|
|
231
|
+
f.write(f'# Reporting for {function}\n')
|
|
232
|
+
f.write('## Parameters\n')
|
|
245
233
|
f.write(
|
|
246
234
|
tabulate(
|
|
247
235
|
param_dict.items(),
|
|
248
|
-
headers=[
|
|
249
|
-
tablefmt=
|
|
250
|
-
)
|
|
251
|
-
|
|
252
|
-
f.write(
|
|
236
|
+
headers=['Parameter', 'Value'],
|
|
237
|
+
tablefmt='github',
|
|
238
|
+
)
|
|
239
|
+
)
|
|
240
|
+
f.write('\n---\n')
|
|
241
|
+
f.write('## Profiling Data\n')
|
|
253
242
|
f.write(
|
|
254
243
|
tabulate(
|
|
255
244
|
profiling_result.items(),
|
|
256
|
-
headers=[
|
|
257
|
-
tablefmt=
|
|
258
|
-
)
|
|
259
|
-
|
|
260
|
-
f.write(
|
|
245
|
+
headers=['Parameter', 'Value'],
|
|
246
|
+
tablefmt='github',
|
|
247
|
+
)
|
|
248
|
+
)
|
|
249
|
+
f.write('\n---\n')
|
|
250
|
+
f.write('## Iteration Runs\n')
|
|
261
251
|
f.write(
|
|
262
252
|
tabulate(
|
|
263
253
|
iteration_runs.items(),
|
|
264
|
-
headers=[
|
|
265
|
-
tablefmt=
|
|
266
|
-
)
|
|
254
|
+
headers=['Iteration', 'Time'],
|
|
255
|
+
tablefmt='github',
|
|
256
|
+
)
|
|
257
|
+
)
|
|
267
258
|
if self.jax_fn and self.compile_info:
|
|
268
|
-
f.write(
|
|
269
|
-
f.write(
|
|
270
|
-
f.write(
|
|
271
|
-
f.write(self.compiled_code[
|
|
272
|
-
f.write(
|
|
273
|
-
f.write(
|
|
274
|
-
f.write(
|
|
275
|
-
f.write(
|
|
276
|
-
f.write(self.compiled_code[
|
|
277
|
-
f.write(
|
|
278
|
-
f.write(
|
|
259
|
+
f.write('\n---\n')
|
|
260
|
+
f.write('## Compiled Code\n')
|
|
261
|
+
f.write('```hlo\n')
|
|
262
|
+
f.write(self.compiled_code['COMPILED'])
|
|
263
|
+
f.write('\n```\n')
|
|
264
|
+
f.write('\n---\n')
|
|
265
|
+
f.write('## Lowered Code\n')
|
|
266
|
+
f.write('```hlo\n')
|
|
267
|
+
f.write(self.compiled_code['LOWERED'])
|
|
268
|
+
f.write('\n```\n')
|
|
269
|
+
f.write('\n---\n')
|
|
279
270
|
if self.save_jaxpr:
|
|
280
|
-
f.write(
|
|
281
|
-
f.write(
|
|
282
|
-
f.write(self.compiled_code[
|
|
283
|
-
f.write(
|
|
271
|
+
f.write('## JAXPR\n')
|
|
272
|
+
f.write('```haskel\n')
|
|
273
|
+
f.write(self.compiled_code['JAXPR'])
|
|
274
|
+
f.write('\n```\n')
|
|
284
275
|
|
|
285
276
|
# Reset the timer
|
|
286
277
|
self.jit_time = 0.0
|