jax-hpc-profiler 0.2.2__tar.gz → 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/PKG-INFO +1 -1
  2. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/pyproject.toml +1 -1
  3. jax_hpc_profiler-0.2.5/src/jax_hpc_profiler/timer.py +225 -0
  4. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
  5. jax_hpc_profiler-0.2.2/src/jax_hpc_profiler/timer.py +0 -220
  6. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/LICENSE +0 -0
  7. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/README.md +0 -0
  8. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/setup.cfg +0 -0
  9. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/__init__.py +0 -0
  10. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/create_argparse.py +0 -0
  11. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/main.py +0 -0
  12. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/plotting.py +0 -0
  13. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler/utils.py +0 -0
  14. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  15. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  16. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  17. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  18. {jax_hpc_profiler-0.2.2 → jax_hpc_profiler-0.2.5}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.2
3
+ Version: 0.2.5
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.2"
7
+ version = "0.2.5"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -0,0 +1,225 @@
1
+ import os
2
+ import time
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Tuple
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from jax import make_jaxpr
10
+ from jax.experimental import mesh_utils
11
+ from jax.experimental.shard_map import shard_map
12
+ from jax.sharding import Mesh, NamedSharding
13
+ from jax.sharding import PartitionSpec as P
14
+ from tabulate import tabulate
15
+
16
+
17
+ class Timer:
18
+
19
+ def __init__(self, save_jaxpr=False):
20
+ self.jit_time = None
21
+ self.times = []
22
+ self.profiling_data = {}
23
+ self.compiled_code = {}
24
+ self.save_jaxpr = save_jaxpr
25
+
26
+ def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
27
+ if cost_analysis is None:
28
+ return None
29
+ return cost_analysis[0]['flops']
30
+
31
+ def _normalize_memory_units(self, memory_analysis) -> str:
32
+
33
+ sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
34
+ factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
+ factor = int(np.log10(memory_analysis) // 3)
36
+
37
+ return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
+
39
+ def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
40
+ if memory_analysis is None:
41
+ return None, None, None, None
42
+ return (memory_analysis.generated_code_size_in_bytes,
43
+ memory_analysis.argument_size_in_bytes,
44
+ memory_analysis.output_size_in_bytes,
45
+ memory_analysis.temp_size_in_bytes)
46
+
47
+ def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
48
+ start = time.perf_counter()
49
+ out = fun(*args)
50
+ if ndarray_arg is None:
51
+ out.block_until_ready()
52
+ else:
53
+ out[ndarray_arg].block_until_ready()
54
+ end = time.perf_counter()
55
+ self.jit_time = (end - start) * 1e3
56
+
57
+ if self.save_jaxpr:
58
+ jaxpr = make_jaxpr(fun)(*args)
59
+ self.compiled_code["JAXPR"] = jaxpr.pretty_print()
60
+
61
+ lowered = jax.jit(fun).lower(*args)
62
+ compiled = lowered.compile()
63
+ memory_analysis = self._read_memory_analysis(
64
+ compiled.memory_analysis())
65
+ cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
66
+
67
+ self.compiled_code["LOWERED"] = lowered.as_text()
68
+ self.compiled_code["COMPILED"] = compiled.as_text()
69
+ self.profiling_data["FLOPS"] = cost_analysis
70
+ self.profiling_data["generated_code"] = memory_analysis[0]
71
+ self.profiling_data["argument_size"] = memory_analysis[1]
72
+ self.profiling_data["output_size"] = memory_analysis[2]
73
+ self.profiling_data["temp_size"] = memory_analysis[3]
74
+ return out
75
+
76
+ def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
77
+ start = time.perf_counter()
78
+ out = fun(*args)
79
+ if ndarray_arg is None:
80
+ out.block_until_ready()
81
+ else:
82
+ out[ndarray_arg].block_until_ready()
83
+ end = time.perf_counter()
84
+ self.times.append((end - start) * 1e3)
85
+ return out
86
+
87
+ def _get_mean_times(self) -> np.ndarray:
88
+ if jax.device_count() == 1:
89
+ return np.array(self.times)
90
+
91
+ devices = mesh_utils.create_device_mesh((jax.device_count(), ))
92
+ mesh = Mesh(devices, ('x', ))
93
+ sharding = NamedSharding(mesh, P('x'))
94
+
95
+ times_array = jnp.array(self.times)
96
+ global_shape = (jax.device_count(), times_array.shape[0])
97
+ global_times = jax.make_array_from_callback(
98
+ shape=global_shape,
99
+ sharding=sharding,
100
+ data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
101
+
102
+ @partial(shard_map,
103
+ mesh=mesh,
104
+ in_specs=P('x'),
105
+ out_specs=P(),
106
+ check_rep=False)
107
+ def get_mean_times(times):
108
+ return jax.lax.pmean(times, axis_name='x')
109
+
110
+ times_array = get_mean_times(global_times)
111
+ times_array.block_until_ready()
112
+ return np.array(times_array.addressable_data(0)[0])
113
+
114
+ def report(self,
115
+ csv_filename: str,
116
+ function: str,
117
+ x: int,
118
+ y: int | None = None,
119
+ z: int | None = None,
120
+ precision: str = "float32",
121
+ px: int = 1,
122
+ py: int = 1,
123
+ backend: str = "NCCL",
124
+ nodes: int = 1,
125
+ md_filename: str | None = None,
126
+ extra_info: dict = {}):
127
+
128
+ if md_filename is None:
129
+ dirname, filename = os.path.dirname(
130
+ csv_filename), os.path.splitext(
131
+ os.path.basename(csv_filename))[0]
132
+ report_folder = filename if dirname == "" else f"{dirname}/{filename}"
133
+ os.makedirs(report_folder, exist_ok=True)
134
+ md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
135
+
136
+ y = x if y is None else y
137
+ z = x if z is None else z
138
+
139
+ times_array = self._get_mean_times()
140
+ if jax.process_index() == 0:
141
+ min_time = np.min(times_array)
142
+ max_time = np.max(times_array)
143
+ mean_time = np.mean(times_array)
144
+ std_time = np.std(times_array)
145
+ last_time = times_array[-1]
146
+
147
+ flops = self.profiling_data["FLOPS"]
148
+ generated_code = self.profiling_data["generated_code"]
149
+ argument_size = self.profiling_data["argument_size"]
150
+ output_size = self.profiling_data["output_size"]
151
+ temp_size = self.profiling_data["temp_size"]
152
+
153
+ csv_line = (
154
+ f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
155
+ f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
156
+ f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
157
+ )
158
+
159
+ with open(csv_filename, 'a') as f:
160
+ f.write(csv_line)
161
+
162
+ param_dict = {
163
+ "Function": function,
164
+ "Precision": precision,
165
+ "X": x,
166
+ "Y": y,
167
+ "Z": z,
168
+ "PX": px,
169
+ "PY": py,
170
+ "Backend": backend,
171
+ "Nodes": nodes,
172
+ }
173
+ param_dict.update(extra_info)
174
+ profiling_result = {
175
+ "JIT Time": self.jit_time,
176
+ "Min Time": min_time,
177
+ "Max Time": max_time,
178
+ "Mean Time": mean_time,
179
+ "Std Time": std_time,
180
+ "Last Time": last_time,
181
+ "Generated Code": self._normalize_memory_units(generated_code),
182
+ "Argument Size": self._normalize_memory_units(argument_size),
183
+ "Output Size": self._normalize_memory_units(output_size),
184
+ "Temporary Size": self._normalize_memory_units(temp_size),
185
+ "FLOPS": self.profiling_data["FLOPS"]
186
+ }
187
+ iteration_runs = {}
188
+ for i in range(len(times_array)):
189
+ iteration_runs[f"Run {i}"] = times_array[i]
190
+
191
+ with open(md_filename, 'w') as f:
192
+ f.write(f"# Reporting for {function}\n")
193
+ f.write(f"## Parameters\n")
194
+ f.write(
195
+ tabulate(param_dict.items(),
196
+ headers=["Parameter", "Value"],
197
+ tablefmt='github'))
198
+ f.write("\n---\n")
199
+ f.write(f"## Profiling Data\n")
200
+ f.write(
201
+ tabulate(profiling_result.items(),
202
+ headers=["Parameter", "Value"],
203
+ tablefmt='github'))
204
+ f.write("\n---\n")
205
+ f.write(f"## Iteration Runs\n")
206
+ f.write(
207
+ tabulate(iteration_runs.items(),
208
+ headers=["Iteration", "Time"],
209
+ tablefmt='github'))
210
+ f.write("\n---\n")
211
+ f.write(f"## Compiled Code\n")
212
+ f.write(f"```hlo\n")
213
+ f.write(self.compiled_code["COMPILED"])
214
+ f.write(f"\n```\n")
215
+ f.write("\n---\n")
216
+ f.write(f"## Lowered Code\n")
217
+ f.write(f"```hlo\n")
218
+ f.write(self.compiled_code["LOWERED"])
219
+ f.write(f"\n```\n")
220
+ f.write("\n---\n")
221
+ if self.save_jaxpr:
222
+ f.write(f"## JAXPR\n")
223
+ f.write(f"```haskel\n")
224
+ f.write(self.compiled_code["JAXPR"])
225
+ f.write(f"\n```\n")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.2
3
+ Version: 0.2.5
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -1,220 +0,0 @@
1
- import os
2
- import time
3
- from functools import partial
4
- from typing import Any, Callable, List, Tuple
5
-
6
- import jax
7
- import jax.numpy as jnp
8
- import numpy as np
9
- from jax import make_jaxpr
10
- from jax.experimental import mesh_utils
11
- from jax.experimental.shard_map import shard_map
12
- from jax.sharding import Mesh, NamedSharding
13
- from jax.sharding import PartitionSpec as P
14
- from tabulate import tabulate
15
-
16
-
17
- class Timer:
18
-
19
- def __init__(self, save_jaxpr=False):
20
- self.jit_time = None
21
- self.times = []
22
- self.profiling_data = {}
23
- self.compiled_code = {}
24
- self.save_jaxpr = save_jaxpr
25
-
26
- def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
27
- if cost_analysis is None:
28
- return None
29
- return cost_analysis[0]['flops']
30
-
31
- def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
32
- if memory_analysis is None:
33
- return None, None, None, None
34
- return (memory_analysis.generated_code_size_in_bytes,
35
- memory_analysis.argument_size_in_bytes,
36
- memory_analysis.output_size_in_bytes,
37
- memory_analysis.temp_size_in_bytes)
38
-
39
- def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
40
- start = time.perf_counter()
41
- out = jax.jit(fun)(*args)
42
- if ndarray_arg is None:
43
- out.block_until_ready()
44
- else:
45
- out[ndarray_arg].block_until_ready()
46
- end = time.perf_counter()
47
- self.jit_time = (end - start) * 1e3
48
-
49
- if self.save_jaxpr:
50
- jaxpr = make_jaxpr(fun)(*args)
51
- self.compiled_code["JAXPR"] = jaxpr.pretty_print()
52
-
53
- lowered = jax.jit(fun).lower(*args)
54
- compiled = lowered.compile()
55
- memory_analysis = self._read_memory_analysis(
56
- compiled.memory_analysis())
57
- cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
58
-
59
- self.compiled_code["LOWERED"] = lowered.as_text()
60
- self.compiled_code["COMPILED"] = compiled.as_text()
61
- self.profiling_data["FLOPS"] = cost_analysis
62
- self.profiling_data[
63
- "generated_code"] = memory_analysis[0]
64
- self.profiling_data[
65
- "argument_size"] = memory_analysis[0]
66
- self.profiling_data[
67
- "output_size"] = memory_analysis[0]
68
- self.profiling_data["temp_size"] = memory_analysis[0]
69
- return out
70
-
71
- def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
72
- start = time.perf_counter()
73
- out = fun(*args)
74
- if ndarray_arg is None:
75
- out.block_until_ready()
76
- else:
77
- out[ndarray_arg].block_until_ready()
78
- end = time.perf_counter()
79
- self.times.append((end - start) * 1e3)
80
- return out
81
-
82
- def _get_mean_times(self) -> np.ndarray:
83
- if jax.device_count() == 1:
84
- return np.array(self.times)
85
-
86
- devices = mesh_utils.create_device_mesh((jax.device_count(), ))
87
- mesh = Mesh(devices, ('x', ))
88
- sharding = NamedSharding(mesh, P('x'))
89
-
90
- times_array = jnp.array(self.times)
91
- global_shape = (jax.device_count(), times_array.shape[0])
92
- global_times = jax.make_array_from_callback(
93
- shape=global_shape,
94
- sharding=sharding,
95
- data_callback=lambda x: times_array)
96
-
97
- @partial(shard_map,
98
- mesh=mesh,
99
- in_specs=P('x'),
100
- out_specs=P(),
101
- check_rep=False)
102
- def get_mean_times(times):
103
- return jax.lax.pmean(times, axis_name='x')
104
-
105
- times_array = get_mean_times(global_times)
106
- times_array.block_until_ready()
107
- return np.array(times_array.addressable_data(0))
108
-
109
- def report(self,
110
- csv_filename: str,
111
- function: str,
112
- x: int,
113
- y: int | None = None,
114
- z: int | None = None,
115
- precision: str = "float32",
116
- px: int = 1,
117
- py: int = 1,
118
- backend: str = "NCCL",
119
- nodes: int = 1,
120
- md_filename: str | None = None,
121
- extra_info: dict = {}):
122
-
123
- if md_filename is None:
124
- dirname, filename = os.path.dirname(
125
- csv_filename), os.path.splitext(
126
- os.path.basename(csv_filename))[0]
127
- report_folder = filename if dirname == "" else f"{dirname}/{filename}"
128
- os.makedirs(report_folder, exist_ok=True)
129
- md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
130
-
131
- y = x if y is None else y
132
- z = x if z is None else z
133
-
134
- times_array = self._get_mean_times()
135
-
136
- min_time = np.min(times_array)
137
- max_time = np.max(times_array)
138
- mean_time = np.mean(times_array)
139
- std_time = np.std(times_array)
140
- last_time = times_array[-1]
141
-
142
- flops = self.profiling_data["FLOPS"]
143
- generated_code = self.profiling_data["generated_code"]
144
- argument_size = self.profiling_data["argument_size"]
145
- output_size = self.profiling_data["output_size"]
146
- temp_size = self.profiling_data["temp_size"]
147
-
148
- csv_line = (
149
- f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
150
- f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
151
- f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
152
- )
153
-
154
- with open(csv_filename, 'a') as f:
155
- f.write(csv_line)
156
-
157
- param_dict = {
158
- "Function": function,
159
- "Precision": precision,
160
- "X": x,
161
- "Y": y,
162
- "Z": z,
163
- "PX": px,
164
- "PY": py,
165
- "Backend": backend,
166
- "Nodes": nodes,
167
- }
168
- param_dict.update(extra_info)
169
- profiling_result = {
170
- "JIT Time": self.jit_time,
171
- "Min Time": min_time,
172
- "Max Time": max_time,
173
- "Mean Time": mean_time,
174
- "Std Time": std_time,
175
- "Last Time": last_time,
176
- "Generated Code": generated_code,
177
- "Argument Size": argument_size,
178
- "Output Size": output_size,
179
- "Temporary Size": temp_size,
180
- "FLOPS": self.profiling_data["FLOPS"]
181
- }
182
- iteration_runs = {}
183
- for i in range(len(times_array)):
184
- iteration_runs[f"Run {i}"] = times_array[i]
185
-
186
- with open(md_filename, 'w') as f:
187
- f.write(f"# Reporting for {function}\n")
188
- f.write(f"## Parameters\n")
189
- f.write(
190
- tabulate(param_dict.items(),
191
- headers=["Parameter", "Value"],
192
- tablefmt='github'))
193
- f.write("\n---\n")
194
- f.write(f"## Profiling Data\n")
195
- f.write(
196
- tabulate(profiling_result.items(),
197
- headers=["Parameter", "Value"],
198
- tablefmt='github'))
199
- f.write("\n---\n")
200
- f.write(f"## Iteration Runs\n")
201
- f.write(
202
- tabulate(iteration_runs.items(),
203
- headers=["Iteration", "Time"],
204
- tablefmt='github'))
205
- f.write("\n---\n")
206
- f.write(f"## Compiled Code\n")
207
- f.write(f"```hlo\n")
208
- f.write(self.compiled_code["COMPILED"])
209
- f.write(f"\n```\n")
210
- f.write("\n---\n")
211
- f.write(f"## Lowered Code\n")
212
- f.write(f"```hlo\n")
213
- f.write(self.compiled_code["LOWERED"])
214
- f.write(f"\n```\n")
215
- f.write("\n---\n")
216
- if self.save_jaxpr:
217
- f.write(f"## JAXPR\n")
218
- f.write(f"```haskel\n")
219
- f.write(self.compiled_code["JAXPR"])
220
- f.write(f"\n```\n")