jax-hpc-profiler 0.2.2__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/pyproject.toml +1 -1
- jax_hpc_profiler-0.2.5/src/jax_hpc_profiler/timer.py +225 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
- jax_hpc_profiler-0.2.2/src/jax_hpc_profiler/timer.py +0 -220
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/README.md +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/__init__.py +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/create_argparse.py +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/main.py +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/plotting.py +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/utils.py +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
jax_hpc_profiler-0.2.5/src/jax_hpc_profiler/timer.py (new file, +225 lines):

```diff
@@ -0,0 +1,225 @@
+import os
+import time
+from functools import partial
+from typing import Any, Callable, List, Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax import make_jaxpr
+from jax.experimental import mesh_utils
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from tabulate import tabulate
+
+
+class Timer:
+
+    def __init__(self, save_jaxpr=False):
+        self.jit_time = None
+        self.times = []
+        self.profiling_data = {}
+        self.compiled_code = {}
+        self.save_jaxpr = save_jaxpr
+
+    def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
+        if cost_analysis is None:
+            return None
+        return cost_analysis[0]['flops']
+
+    def _normalize_memory_units(self, memory_analysis) -> str:
+
+        sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
+        factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
+        factor = int(np.log10(memory_analysis) // 3)
+
+        return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
+
+    def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
+        if memory_analysis is None:
+            return None, None, None, None
+        return (memory_analysis.generated_code_size_in_bytes,
+                memory_analysis.argument_size_in_bytes,
+                memory_analysis.output_size_in_bytes,
+                memory_analysis.temp_size_in_bytes)
+
+    def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
+        start = time.perf_counter()
+        out = fun(*args)
+        if ndarray_arg is None:
+            out.block_until_ready()
+        else:
+            out[ndarray_arg].block_until_ready()
+        end = time.perf_counter()
+        self.jit_time = (end - start) * 1e3
+
+        if self.save_jaxpr:
+            jaxpr = make_jaxpr(fun)(*args)
+            self.compiled_code["JAXPR"] = jaxpr.pretty_print()
+
+        lowered = jax.jit(fun).lower(*args)
+        compiled = lowered.compile()
+        memory_analysis = self._read_memory_analysis(
+            compiled.memory_analysis())
+        cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
+
+        self.compiled_code["LOWERED"] = lowered.as_text()
+        self.compiled_code["COMPILED"] = compiled.as_text()
+        self.profiling_data["FLOPS"] = cost_analysis
+        self.profiling_data["generated_code"] = memory_analysis[0]
+        self.profiling_data["argument_size"] = memory_analysis[1]
+        self.profiling_data["output_size"] = memory_analysis[2]
+        self.profiling_data["temp_size"] = memory_analysis[3]
+        return out
+
+    def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
+        start = time.perf_counter()
+        out = fun(*args)
+        if ndarray_arg is None:
+            out.block_until_ready()
+        else:
+            out[ndarray_arg].block_until_ready()
+        end = time.perf_counter()
+        self.times.append((end - start) * 1e3)
+        return out
+
+    def _get_mean_times(self) -> np.ndarray:
+        if jax.device_count() == 1:
+            return np.array(self.times)
+
+        devices = mesh_utils.create_device_mesh((jax.device_count(), ))
+        mesh = Mesh(devices, ('x', ))
+        sharding = NamedSharding(mesh, P('x'))
+
+        times_array = jnp.array(self.times)
+        global_shape = (jax.device_count(), times_array.shape[0])
+        global_times = jax.make_array_from_callback(
+            shape=global_shape,
+            sharding=sharding,
+            data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
+
+        @partial(shard_map,
+                 mesh=mesh,
+                 in_specs=P('x'),
+                 out_specs=P(),
+                 check_rep=False)
+        def get_mean_times(times):
+            return jax.lax.pmean(times, axis_name='x')
+
+        times_array = get_mean_times(global_times)
+        times_array.block_until_ready()
+        return np.array(times_array.addressable_data(0)[0])
+
+    def report(self,
+               csv_filename: str,
+               function: str,
+               x: int,
+               y: int | None = None,
+               z: int | None = None,
+               precision: str = "float32",
+               px: int = 1,
+               py: int = 1,
+               backend: str = "NCCL",
+               nodes: int = 1,
+               md_filename: str | None = None,
+               extra_info: dict = {}):
+
+        if md_filename is None:
+            dirname, filename = os.path.dirname(
+                csv_filename), os.path.splitext(
+                    os.path.basename(csv_filename))[0]
+            report_folder = filename if dirname == "" else f"{dirname}/{filename}"
+            os.makedirs(report_folder, exist_ok=True)
+            md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
+
+        y = x if y is None else y
+        z = x if z is None else z
+
+        times_array = self._get_mean_times()
+        if jax.process_index() == 0:
+            min_time = np.min(times_array)
+            max_time = np.max(times_array)
+            mean_time = np.mean(times_array)
+            std_time = np.std(times_array)
+            last_time = times_array[-1]
+
+            flops = self.profiling_data["FLOPS"]
+            generated_code = self.profiling_data["generated_code"]
+            argument_size = self.profiling_data["argument_size"]
+            output_size = self.profiling_data["output_size"]
+            temp_size = self.profiling_data["temp_size"]
+
+            csv_line = (
+                f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
+                f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
+                f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
+            )
+
+            with open(csv_filename, 'a') as f:
+                f.write(csv_line)
+
+            param_dict = {
+                "Function": function,
+                "Precision": precision,
+                "X": x,
+                "Y": y,
+                "Z": z,
+                "PX": px,
+                "PY": py,
+                "Backend": backend,
+                "Nodes": nodes,
+            }
+            param_dict.update(extra_info)
+            profiling_result = {
+                "JIT Time": self.jit_time,
+                "Min Time": min_time,
+                "Max Time": max_time,
+                "Mean Time": mean_time,
+                "Std Time": std_time,
+                "Last Time": last_time,
+                "Generated Code": self._normalize_memory_units(generated_code),
+                "Argument Size": self._normalize_memory_units(argument_size),
+                "Output Size": self._normalize_memory_units(output_size),
+                "Temporary Size": self._normalize_memory_units(temp_size),
+                "FLOPS": self.profiling_data["FLOPS"]
+            }
+            iteration_runs = {}
+            for i in range(len(times_array)):
+                iteration_runs[f"Run {i}"] = times_array[i]
+
+            with open(md_filename, 'w') as f:
+                f.write(f"# Reporting for {function}\n")
+                f.write(f"## Parameters\n")
+                f.write(
+                    tabulate(param_dict.items(),
+                             headers=["Parameter", "Value"],
+                             tablefmt='github'))
+                f.write("\n---\n")
+                f.write(f"## Profiling Data\n")
+                f.write(
+                    tabulate(profiling_result.items(),
+                             headers=["Parameter", "Value"],
+                             tablefmt='github'))
+                f.write("\n---\n")
+                f.write(f"## Iteration Runs\n")
+                f.write(
+                    tabulate(iteration_runs.items(),
+                             headers=["Iteration", "Time"],
+                             tablefmt='github'))
+                f.write("\n---\n")
+                f.write(f"## Compiled Code\n")
+                f.write(f"```hlo\n")
+                f.write(self.compiled_code["COMPILED"])
+                f.write(f"\n```\n")
+                f.write("\n---\n")
+                f.write(f"## Lowered Code\n")
+                f.write(f"```hlo\n")
+                f.write(self.compiled_code["LOWERED"])
+                f.write(f"\n```\n")
+                f.write("\n---\n")
+                if self.save_jaxpr:
+                    f.write(f"## JAXPR\n")
+                    f.write(f"```haskel\n")
+                    f.write(self.compiled_code["JAXPR"])
+                    f.write(f"\n```\n")
```
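For context, a minimal usage sketch of the `Timer` API added above (not part of the package diff): the benchmarked function `fft3d`, the array shape, and the CSV path are illustrative only, and the import assumes `Timer` is re-exported from `jax_hpc_profiler/__init__.py`; if it is not, import it from `jax_hpc_profiler.timer` instead.

```python
import jax
import jax.numpy as jnp

from jax_hpc_profiler import Timer  # assumption: Timer is exposed at package level

@jax.jit
def fft3d(x):
    # Illustrative workload; any JAX callable returning an array (or a container,
    # with ndarray_arg pointing at the array to wait on) can be timed the same way.
    return jnp.fft.fftn(x)

x = jnp.ones((64, 64, 64), dtype=jnp.float32)

timer = Timer(save_jaxpr=True)
# First call: records compile + first-run wall time (ms) in timer.jit_time and
# captures the lowered/compiled text plus cost and memory analyses.
timer.chrono_jit(fft3d, x)
# Subsequent calls: each wall-clock time (ms) is appended to timer.times.
for _ in range(10):
    timer.chrono_fun(fft3d, x)
# Appends one line to the CSV and writes a markdown report in a sibling folder.
timer.report("fft3d.csv", function="fft3d", x=64)
```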
jax_hpc_profiler-0.2.2/src/jax_hpc_profiler/timer.py (removed, -220 lines):

```diff
@@ -1,220 +0,0 @@
-import os
-import time
-from functools import partial
-from typing import Any, Callable, List, Tuple
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-from jax import make_jaxpr
-from jax.experimental import mesh_utils
-from jax.experimental.shard_map import shard_map
-from jax.sharding import Mesh, NamedSharding
-from jax.sharding import PartitionSpec as P
-from tabulate import tabulate
-
-
-class Timer:
-
-    def __init__(self, save_jaxpr=False):
-        self.jit_time = None
-        self.times = []
-        self.profiling_data = {}
-        self.compiled_code = {}
-        self.save_jaxpr = save_jaxpr
-
-    def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
-        if cost_analysis is None:
-            return None
-        return cost_analysis[0]['flops']
-
-    def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
-        if memory_analysis is None:
-            return None, None, None, None
-        return (memory_analysis.generated_code_size_in_bytes,
-                memory_analysis.argument_size_in_bytes,
-                memory_analysis.output_size_in_bytes,
-                memory_analysis.temp_size_in_bytes)
-
-    def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
-        start = time.perf_counter()
-        out = jax.jit(fun)(*args)
-        if ndarray_arg is None:
-            out.block_until_ready()
-        else:
-            out[ndarray_arg].block_until_ready()
-        end = time.perf_counter()
-        self.jit_time = (end - start) * 1e3
-
-        if self.save_jaxpr:
-            jaxpr = make_jaxpr(fun)(*args)
-            self.compiled_code["JAXPR"] = jaxpr.pretty_print()
-
-        lowered = jax.jit(fun).lower(*args)
-        compiled = lowered.compile()
-        memory_analysis = self._read_memory_analysis(
-            compiled.memory_analysis())
-        cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
-
-        self.compiled_code["LOWERED"] = lowered.as_text()
-        self.compiled_code["COMPILED"] = compiled.as_text()
-        self.profiling_data["FLOPS"] = cost_analysis
-        self.profiling_data[
-            "generated_code"] = memory_analysis[0]
-        self.profiling_data[
-            "argument_size"] = memory_analysis[0]
-        self.profiling_data[
-            "output_size"] = memory_analysis[0]
-        self.profiling_data["temp_size"] = memory_analysis[0]
-        return out
-
-    def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
-        start = time.perf_counter()
-        out = fun(*args)
-        if ndarray_arg is None:
-            out.block_until_ready()
-        else:
-            out[ndarray_arg].block_until_ready()
-        end = time.perf_counter()
-        self.times.append((end - start) * 1e3)
-        return out
-
-    def _get_mean_times(self) -> np.ndarray:
-        if jax.device_count() == 1:
-            return np.array(self.times)
-
-        devices = mesh_utils.create_device_mesh((jax.device_count(), ))
-        mesh = Mesh(devices, ('x', ))
-        sharding = NamedSharding(mesh, P('x'))
-
-        times_array = jnp.array(self.times)
-        global_shape = (jax.device_count(), times_array.shape[0])
-        global_times = jax.make_array_from_callback(
-            shape=global_shape,
-            sharding=sharding,
-            data_callback=lambda x: times_array)
-
-        @partial(shard_map,
-                 mesh=mesh,
-                 in_specs=P('x'),
-                 out_specs=P(),
-                 check_rep=False)
-        def get_mean_times(times):
-            return jax.lax.pmean(times, axis_name='x')
-
-        times_array = get_mean_times(global_times)
-        times_array.block_until_ready()
-        return np.array(times_array.addressable_data(0))
-
-    def report(self,
-               csv_filename: str,
-               function: str,
-               x: int,
-               y: int | None = None,
-               z: int | None = None,
-               precision: str = "float32",
-               px: int = 1,
-               py: int = 1,
-               backend: str = "NCCL",
-               nodes: int = 1,
-               md_filename: str | None = None,
-               extra_info: dict = {}):
-
-        if md_filename is None:
-            dirname, filename = os.path.dirname(
-                csv_filename), os.path.splitext(
-                    os.path.basename(csv_filename))[0]
-            report_folder = filename if dirname == "" else f"{dirname}/{filename}"
-            os.makedirs(report_folder, exist_ok=True)
-            md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
-
-        y = x if y is None else y
-        z = x if z is None else z
-
-        times_array = self._get_mean_times()
-
-        min_time = np.min(times_array)
-        max_time = np.max(times_array)
-        mean_time = np.mean(times_array)
-        std_time = np.std(times_array)
-        last_time = times_array[-1]
-
-        flops = self.profiling_data["FLOPS"]
-        generated_code = self.profiling_data["generated_code"]
-        argument_size = self.profiling_data["argument_size"]
-        output_size = self.profiling_data["output_size"]
-        temp_size = self.profiling_data["temp_size"]
-
-        csv_line = (
-            f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
-            f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
-            f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
-        )
-
-        with open(csv_filename, 'a') as f:
-            f.write(csv_line)
-
-        param_dict = {
-            "Function": function,
-            "Precision": precision,
-            "X": x,
-            "Y": y,
-            "Z": z,
-            "PX": px,
-            "PY": py,
-            "Backend": backend,
-            "Nodes": nodes,
-        }
-        param_dict.update(extra_info)
-        profiling_result = {
-            "JIT Time": self.jit_time,
-            "Min Time": min_time,
-            "Max Time": max_time,
-            "Mean Time": mean_time,
-            "Std Time": std_time,
-            "Last Time": last_time,
-            "Generated Code": generated_code,
-            "Argument Size": argument_size,
-            "Output Size": output_size,
-            "Temporary Size": temp_size,
-            "FLOPS": self.profiling_data["FLOPS"]
-        }
-        iteration_runs = {}
-        for i in range(len(times_array)):
-            iteration_runs[f"Run {i}"] = times_array[i]
-
-        with open(md_filename, 'w') as f:
-            f.write(f"# Reporting for {function}\n")
-            f.write(f"## Parameters\n")
-            f.write(
-                tabulate(param_dict.items(),
-                         headers=["Parameter", "Value"],
-                         tablefmt='github'))
-            f.write("\n---\n")
-            f.write(f"## Profiling Data\n")
-            f.write(
-                tabulate(profiling_result.items(),
-                         headers=["Parameter", "Value"],
-                         tablefmt='github'))
-            f.write("\n---\n")
-            f.write(f"## Iteration Runs\n")
-            f.write(
-                tabulate(iteration_runs.items(),
-                         headers=["Iteration", "Time"],
-                         tablefmt='github'))
-            f.write("\n---\n")
-            f.write(f"## Compiled Code\n")
-            f.write(f"```hlo\n")
-            f.write(self.compiled_code["COMPILED"])
-            f.write(f"\n```\n")
-            f.write("\n---\n")
-            f.write(f"## Lowered Code\n")
-            f.write(f"```hlo\n")
-            f.write(self.compiled_code["LOWERED"])
-            f.write(f"\n```\n")
-            f.write("\n---\n")
-            if self.save_jaxpr:
-                f.write(f"## JAXPR\n")
-                f.write(f"```haskel\n")
-                f.write(self.compiled_code["JAXPR"])
-                f.write(f"\n```\n")
```
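Both versions append the same comma-separated record per `report()` call, in the order laid out by the `csv_line` f-string above; the package writes no header row. A small, hypothetical reader for that file (the column names below are illustrative, not defined by the package):

```python
import csv

# Column order taken from the csv_line f-string in timer.py; the names are ours.
COLUMNS = [
    "function", "precision", "x", "y", "z", "px", "py", "backend", "nodes",
    "jit_time", "min_time", "max_time", "mean_time", "std_time", "last_time",
    "generated_code", "argument_size", "output_size", "temp_size", "flops",
]

def read_report_csv(path):
    """Return one dict per appended benchmark row."""
    with open(path, newline="") as f:
        return [dict(zip(COLUMNS, row)) for row in csv.reader(f)]

for row in read_report_csv("fft3d.csv"):
    print(row["function"], row["mean_time"], "ms")
```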