jax-hpc-profiler 0.2.3__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/pyproject.toml +1 -1
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/timer.py +94 -99
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/LICENSE +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/README.md +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/setup.cfg +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/__init__.py +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/create_argparse.py +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/main.py +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/plotting.py +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/utils.py +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
- {jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
|
@@ -28,15 +28,14 @@ class Timer:
|
|
|
28
28
|
return None
|
|
29
29
|
return cost_analysis[0]['flops']
|
|
30
30
|
|
|
31
|
-
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
|
-
|
|
33
|
-
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB'
|
|
34
|
-
factors = [1
|
|
31
|
+
def _normalize_memory_units(self, memory_analysis) -> str:
|
|
32
|
+
|
|
33
|
+
sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
|
34
|
+
factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
|
|
35
35
|
factor = int(np.log10(memory_analysis) // 3)
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
|
|
38
38
|
|
|
39
|
-
|
|
40
39
|
def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
|
|
41
40
|
if memory_analysis is None:
|
|
42
41
|
return None, None, None, None
|
|
@@ -68,12 +67,9 @@ class Timer:
|
|
|
68
67
|
self.compiled_code["LOWERED"] = lowered.as_text()
|
|
69
68
|
self.compiled_code["COMPILED"] = compiled.as_text()
|
|
70
69
|
self.profiling_data["FLOPS"] = cost_analysis
|
|
71
|
-
self.profiling_data[
|
|
72
|
-
|
|
73
|
-
self.profiling_data[
|
|
74
|
-
"argument_size"] = memory_analysis[1]
|
|
75
|
-
self.profiling_data[
|
|
76
|
-
"output_size"] = memory_analysis[2]
|
|
70
|
+
self.profiling_data["generated_code"] = memory_analysis[0]
|
|
71
|
+
self.profiling_data["argument_size"] = memory_analysis[1]
|
|
72
|
+
self.profiling_data["output_size"] = memory_analysis[2]
|
|
77
73
|
self.profiling_data["temp_size"] = memory_analysis[3]
|
|
78
74
|
return out
|
|
79
75
|
|
|
@@ -101,7 +97,7 @@ class Timer:
|
|
|
101
97
|
global_times = jax.make_array_from_callback(
|
|
102
98
|
shape=global_shape,
|
|
103
99
|
sharding=sharding,
|
|
104
|
-
data_callback=lambda _: jnp.expand_dims(times_array,axis=0))
|
|
100
|
+
data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
|
|
105
101
|
|
|
106
102
|
@partial(shard_map,
|
|
107
103
|
mesh=mesh,
|
|
@@ -141,90 +137,89 @@ class Timer:
|
|
|
141
137
|
z = x if z is None else z
|
|
142
138
|
|
|
143
139
|
times_array = self._get_mean_times()
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
f.write(f"\n```\n")
|
|
219
|
-
f.write("\n---\n")
|
|
220
|
-
f.write(f"## Lowered Code\n")
|
|
221
|
-
f.write(f"```hlo\n")
|
|
222
|
-
f.write(self.compiled_code["LOWERED"])
|
|
223
|
-
f.write(f"\n```\n")
|
|
224
|
-
f.write("\n---\n")
|
|
225
|
-
if self.save_jaxpr:
|
|
226
|
-
f.write(f"## JAXPR\n")
|
|
227
|
-
f.write(f"```haskel\n")
|
|
228
|
-
f.write(self.compiled_code["JAXPR"])
|
|
140
|
+
if jax.process_index() == 0:
|
|
141
|
+
min_time = np.min(times_array)
|
|
142
|
+
max_time = np.max(times_array)
|
|
143
|
+
mean_time = np.mean(times_array)
|
|
144
|
+
std_time = np.std(times_array)
|
|
145
|
+
last_time = times_array[-1]
|
|
146
|
+
|
|
147
|
+
flops = self.profiling_data["FLOPS"]
|
|
148
|
+
generated_code = self.profiling_data["generated_code"]
|
|
149
|
+
argument_size = self.profiling_data["argument_size"]
|
|
150
|
+
output_size = self.profiling_data["output_size"]
|
|
151
|
+
temp_size = self.profiling_data["temp_size"]
|
|
152
|
+
|
|
153
|
+
csv_line = (
|
|
154
|
+
f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
|
|
155
|
+
f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
|
|
156
|
+
f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
with open(csv_filename, 'a') as f:
|
|
160
|
+
f.write(csv_line)
|
|
161
|
+
|
|
162
|
+
param_dict = {
|
|
163
|
+
"Function": function,
|
|
164
|
+
"Precision": precision,
|
|
165
|
+
"X": x,
|
|
166
|
+
"Y": y,
|
|
167
|
+
"Z": z,
|
|
168
|
+
"PX": px,
|
|
169
|
+
"PY": py,
|
|
170
|
+
"Backend": backend,
|
|
171
|
+
"Nodes": nodes,
|
|
172
|
+
}
|
|
173
|
+
param_dict.update(extra_info)
|
|
174
|
+
profiling_result = {
|
|
175
|
+
"JIT Time": self.jit_time,
|
|
176
|
+
"Min Time": min_time,
|
|
177
|
+
"Max Time": max_time,
|
|
178
|
+
"Mean Time": mean_time,
|
|
179
|
+
"Std Time": std_time,
|
|
180
|
+
"Last Time": last_time,
|
|
181
|
+
"Generated Code": self._normalize_memory_units(generated_code),
|
|
182
|
+
"Argument Size": self._normalize_memory_units(argument_size),
|
|
183
|
+
"Output Size": self._normalize_memory_units(output_size),
|
|
184
|
+
"Temporary Size": self._normalize_memory_units(temp_size),
|
|
185
|
+
"FLOPS": self.profiling_data["FLOPS"]
|
|
186
|
+
}
|
|
187
|
+
iteration_runs = {}
|
|
188
|
+
for i in range(len(times_array)):
|
|
189
|
+
iteration_runs[f"Run {i}"] = times_array[i]
|
|
190
|
+
|
|
191
|
+
with open(md_filename, 'w') as f:
|
|
192
|
+
f.write(f"# Reporting for {function}\n")
|
|
193
|
+
f.write(f"## Parameters\n")
|
|
194
|
+
f.write(
|
|
195
|
+
tabulate(param_dict.items(),
|
|
196
|
+
headers=["Parameter", "Value"],
|
|
197
|
+
tablefmt='github'))
|
|
198
|
+
f.write("\n---\n")
|
|
199
|
+
f.write(f"## Profiling Data\n")
|
|
200
|
+
f.write(
|
|
201
|
+
tabulate(profiling_result.items(),
|
|
202
|
+
headers=["Parameter", "Value"],
|
|
203
|
+
tablefmt='github'))
|
|
204
|
+
f.write("\n---\n")
|
|
205
|
+
f.write(f"## Iteration Runs\n")
|
|
206
|
+
f.write(
|
|
207
|
+
tabulate(iteration_runs.items(),
|
|
208
|
+
headers=["Iteration", "Time"],
|
|
209
|
+
tablefmt='github'))
|
|
210
|
+
f.write("\n---\n")
|
|
211
|
+
f.write(f"## Compiled Code\n")
|
|
212
|
+
f.write(f"```hlo\n")
|
|
213
|
+
f.write(self.compiled_code["COMPILED"])
|
|
229
214
|
f.write(f"\n```\n")
|
|
230
|
-
|
|
215
|
+
f.write("\n---\n")
|
|
216
|
+
f.write(f"## Lowered Code\n")
|
|
217
|
+
f.write(f"```hlo\n")
|
|
218
|
+
f.write(self.compiled_code["LOWERED"])
|
|
219
|
+
f.write(f"\n```\n")
|
|
220
|
+
f.write("\n---\n")
|
|
221
|
+
if self.save_jaxpr:
|
|
222
|
+
f.write(f"## JAXPR\n")
|
|
223
|
+
f.write(f"```haskel\n")
|
|
224
|
+
f.write(self.compiled_code["JAXPR"])
|
|
225
|
+
f.write(f"\n```\n")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/entry_points.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/requires.txt
RENAMED
|
File without changes
|
{jax_hpc_profiler-0.2.3 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/top_level.txt
RENAMED
|
File without changes
|