jax-hpc-profiler 0.2.2__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jax_hpc_profiler/timer.py CHANGED
@@ -28,6 +28,14 @@ class Timer:
28
28
  return None
29
29
  return cost_analysis[0]['flops']
30
30
 
31
+ def _normalize_memory_units(self, memory_analysis) -> str:
32
+
33
+ sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
34
+ factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
+ factor = int(np.log10(memory_analysis) // 3)
36
+
37
+ return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
+
31
39
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
32
40
  if memory_analysis is None:
33
41
  return None, None, None, None
@@ -38,7 +46,7 @@ class Timer:
38
46
 
39
47
  def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
40
48
  start = time.perf_counter()
41
- out = jax.jit(fun)(*args)
49
+ out = fun(*args)
42
50
  if ndarray_arg is None:
43
51
  out.block_until_ready()
44
52
  else:
@@ -59,13 +67,10 @@ class Timer:
59
67
  self.compiled_code["LOWERED"] = lowered.as_text()
60
68
  self.compiled_code["COMPILED"] = compiled.as_text()
61
69
  self.profiling_data["FLOPS"] = cost_analysis
62
- self.profiling_data[
63
- "generated_code"] = memory_analysis[0]
64
- self.profiling_data[
65
- "argument_size"] = memory_analysis[0]
66
- self.profiling_data[
67
- "output_size"] = memory_analysis[0]
68
- self.profiling_data["temp_size"] = memory_analysis[0]
70
+ self.profiling_data["generated_code"] = memory_analysis[0]
71
+ self.profiling_data["argument_size"] = memory_analysis[1]
72
+ self.profiling_data["output_size"] = memory_analysis[2]
73
+ self.profiling_data["temp_size"] = memory_analysis[3]
69
74
  return out
70
75
 
71
76
  def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
@@ -92,7 +97,7 @@ class Timer:
92
97
  global_times = jax.make_array_from_callback(
93
98
  shape=global_shape,
94
99
  sharding=sharding,
95
- data_callback=lambda x: times_array)
100
+ data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
96
101
 
97
102
  @partial(shard_map,
98
103
  mesh=mesh,
@@ -104,7 +109,7 @@ class Timer:
104
109
 
105
110
  times_array = get_mean_times(global_times)
106
111
  times_array.block_until_ready()
107
- return np.array(times_array.addressable_data(0))
112
+ return np.array(times_array.addressable_data(0)[0])
108
113
 
109
114
  def report(self,
110
115
  csv_filename: str,
@@ -132,89 +137,89 @@ class Timer:
132
137
  z = x if z is None else z
133
138
 
134
139
  times_array = self._get_mean_times()
135
-
136
- min_time = np.min(times_array)
137
- max_time = np.max(times_array)
138
- mean_time = np.mean(times_array)
139
- std_time = np.std(times_array)
140
- last_time = times_array[-1]
141
-
142
- flops = self.profiling_data["FLOPS"]
143
- generated_code = self.profiling_data["generated_code"]
144
- argument_size = self.profiling_data["argument_size"]
145
- output_size = self.profiling_data["output_size"]
146
- temp_size = self.profiling_data["temp_size"]
147
-
148
- csv_line = (
149
- f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
150
- f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
151
- f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
152
- )
153
-
154
- with open(csv_filename, 'a') as f:
155
- f.write(csv_line)
156
-
157
- param_dict = {
158
- "Function": function,
159
- "Precision": precision,
160
- "X": x,
161
- "Y": y,
162
- "Z": z,
163
- "PX": px,
164
- "PY": py,
165
- "Backend": backend,
166
- "Nodes": nodes,
167
- }
168
- param_dict.update(extra_info)
169
- profiling_result = {
170
- "JIT Time": self.jit_time,
171
- "Min Time": min_time,
172
- "Max Time": max_time,
173
- "Mean Time": mean_time,
174
- "Std Time": std_time,
175
- "Last Time": last_time,
176
- "Generated Code": generated_code,
177
- "Argument Size": argument_size,
178
- "Output Size": output_size,
179
- "Temporary Size": temp_size,
180
- "FLOPS": self.profiling_data["FLOPS"]
181
- }
182
- iteration_runs = {}
183
- for i in range(len(times_array)):
184
- iteration_runs[f"Run {i}"] = times_array[i]
185
-
186
- with open(md_filename, 'w') as f:
187
- f.write(f"# Reporting for {function}\n")
188
- f.write(f"## Parameters\n")
189
- f.write(
190
- tabulate(param_dict.items(),
191
- headers=["Parameter", "Value"],
192
- tablefmt='github'))
193
- f.write("\n---\n")
194
- f.write(f"## Profiling Data\n")
195
- f.write(
196
- tabulate(profiling_result.items(),
197
- headers=["Parameter", "Value"],
198
- tablefmt='github'))
199
- f.write("\n---\n")
200
- f.write(f"## Iteration Runs\n")
201
- f.write(
202
- tabulate(iteration_runs.items(),
203
- headers=["Iteration", "Time"],
204
- tablefmt='github'))
205
- f.write("\n---\n")
206
- f.write(f"## Compiled Code\n")
207
- f.write(f"```hlo\n")
208
- f.write(self.compiled_code["COMPILED"])
209
- f.write(f"\n```\n")
210
- f.write("\n---\n")
211
- f.write(f"## Lowered Code\n")
212
- f.write(f"```hlo\n")
213
- f.write(self.compiled_code["LOWERED"])
214
- f.write(f"\n```\n")
215
- f.write("\n---\n")
216
- if self.save_jaxpr:
217
- f.write(f"## JAXPR\n")
218
- f.write(f"```haskel\n")
219
- f.write(self.compiled_code["JAXPR"])
140
+ if jax.process_index() == 0:
141
+ min_time = np.min(times_array)
142
+ max_time = np.max(times_array)
143
+ mean_time = np.mean(times_array)
144
+ std_time = np.std(times_array)
145
+ last_time = times_array[-1]
146
+
147
+ flops = self.profiling_data["FLOPS"]
148
+ generated_code = self.profiling_data["generated_code"]
149
+ argument_size = self.profiling_data["argument_size"]
150
+ output_size = self.profiling_data["output_size"]
151
+ temp_size = self.profiling_data["temp_size"]
152
+
153
+ csv_line = (
154
+ f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
155
+ f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
156
+ f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
157
+ )
158
+
159
+ with open(csv_filename, 'a') as f:
160
+ f.write(csv_line)
161
+
162
+ param_dict = {
163
+ "Function": function,
164
+ "Precision": precision,
165
+ "X": x,
166
+ "Y": y,
167
+ "Z": z,
168
+ "PX": px,
169
+ "PY": py,
170
+ "Backend": backend,
171
+ "Nodes": nodes,
172
+ }
173
+ param_dict.update(extra_info)
174
+ profiling_result = {
175
+ "JIT Time": self.jit_time,
176
+ "Min Time": min_time,
177
+ "Max Time": max_time,
178
+ "Mean Time": mean_time,
179
+ "Std Time": std_time,
180
+ "Last Time": last_time,
181
+ "Generated Code": self._normalize_memory_units(generated_code),
182
+ "Argument Size": self._normalize_memory_units(argument_size),
183
+ "Output Size": self._normalize_memory_units(output_size),
184
+ "Temporary Size": self._normalize_memory_units(temp_size),
185
+ "FLOPS": self.profiling_data["FLOPS"]
186
+ }
187
+ iteration_runs = {}
188
+ for i in range(len(times_array)):
189
+ iteration_runs[f"Run {i}"] = times_array[i]
190
+
191
+ with open(md_filename, 'w') as f:
192
+ f.write(f"# Reporting for {function}\n")
193
+ f.write(f"## Parameters\n")
194
+ f.write(
195
+ tabulate(param_dict.items(),
196
+ headers=["Parameter", "Value"],
197
+ tablefmt='github'))
198
+ f.write("\n---\n")
199
+ f.write(f"## Profiling Data\n")
200
+ f.write(
201
+ tabulate(profiling_result.items(),
202
+ headers=["Parameter", "Value"],
203
+ tablefmt='github'))
204
+ f.write("\n---\n")
205
+ f.write(f"## Iteration Runs\n")
206
+ f.write(
207
+ tabulate(iteration_runs.items(),
208
+ headers=["Iteration", "Time"],
209
+ tablefmt='github'))
210
+ f.write("\n---\n")
211
+ f.write(f"## Compiled Code\n")
212
+ f.write(f"```hlo\n")
213
+ f.write(self.compiled_code["COMPILED"])
214
+ f.write(f"\n```\n")
215
+ f.write("\n---\n")
216
+ f.write(f"## Lowered Code\n")
217
+ f.write(f"```hlo\n")
218
+ f.write(self.compiled_code["LOWERED"])
220
219
  f.write(f"\n```\n")
220
+ f.write("\n---\n")
221
+ if self.save_jaxpr:
222
+ f.write(f"## JAXPR\n")
223
+ f.write(f"```haskel\n")
224
+ f.write(self.compiled_code["JAXPR"])
225
+ f.write(f"\n```\n")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.2
3
+ Version: 0.2.5
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -2,11 +2,11 @@ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,
2
2
  jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
3
3
  jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
4
  jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=BhwWrdPIdZHb2KCpyWYXXNQyBxl4QU3XEmfft5FA1Vc,7843
5
+ jax_hpc_profiler/timer.py,sha256=r4Mw2tC82cxvkMPkIy8BuZjKikgxn6cviEgmu6rpC9o,8616
6
6
  jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
- jax_hpc_profiler-0.2.2.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.2.dist-info/METADATA,sha256=jrZBe0YAZIRuAcgSJryVccu5Rgk7z5G6UwplaEBPNb0,49250
9
- jax_hpc_profiler-0.2.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.2.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.2.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.2.dist-info/RECORD,,
7
+ jax_hpc_profiler-0.2.5.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.5.dist-info/METADATA,sha256=6Vk6fA1nz-m8ZZVzSZPsg9xR9iJkFJ01x36pddG0RAM,49250
9
+ jax_hpc_profiler-0.2.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
+ jax_hpc_profiler-0.2.5.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.5.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.5.dist-info/RECORD,,