jax-hpc-profiler 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff shows the differences in content between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
jax_hpc_profiler/timer.py CHANGED
@@ -28,15 +28,14 @@ class Timer:
28
28
  return None
29
29
  return cost_analysis[0]['flops']
30
30
 
31
- def _normalize_memory_units(self, memory_analysis) -> str:
32
-
33
- sizes_str = ['B', 'KB', 'MB', 'GB', 'TB' , 'PB']
34
- factors = [1 , 1024 , 1024**2 , 1024**3 , 1024**4 , 1024**5]
31
+ def _normalize_memory_units(self, memory_analysis) -> str:
32
+
33
+ sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
34
+ factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
35
  factor = int(np.log10(memory_analysis) // 3)
36
-
36
+
37
37
  return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
38
 
39
-
40
39
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
41
40
  if memory_analysis is None:
42
41
  return None, None, None, None
@@ -68,12 +67,9 @@ class Timer:
68
67
  self.compiled_code["LOWERED"] = lowered.as_text()
69
68
  self.compiled_code["COMPILED"] = compiled.as_text()
70
69
  self.profiling_data["FLOPS"] = cost_analysis
71
- self.profiling_data[
72
- "generated_code"] = memory_analysis[0]
73
- self.profiling_data[
74
- "argument_size"] = memory_analysis[1]
75
- self.profiling_data[
76
- "output_size"] = memory_analysis[2]
70
+ self.profiling_data["generated_code"] = memory_analysis[0]
71
+ self.profiling_data["argument_size"] = memory_analysis[1]
72
+ self.profiling_data["output_size"] = memory_analysis[2]
77
73
  self.profiling_data["temp_size"] = memory_analysis[3]
78
74
  return out
79
75
 
@@ -101,7 +97,7 @@ class Timer:
101
97
  global_times = jax.make_array_from_callback(
102
98
  shape=global_shape,
103
99
  sharding=sharding,
104
- data_callback=lambda _: jnp.expand_dims(times_array,axis=0))
100
+ data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
105
101
 
106
102
  @partial(shard_map,
107
103
  mesh=mesh,
@@ -141,90 +137,89 @@ class Timer:
141
137
  z = x if z is None else z
142
138
 
143
139
  times_array = self._get_mean_times()
144
- min_time = np.min(times_array)
145
- max_time = np.max(times_array)
146
- mean_time = np.mean(times_array)
147
- std_time = np.std(times_array)
148
- last_time = times_array[-1]
149
-
150
-
151
- flops = self.profiling_data["FLOPS"]
152
- generated_code = self.profiling_data["generated_code"]
153
- argument_size = self.profiling_data["argument_size"]
154
- output_size = self.profiling_data["output_size"]
155
- temp_size = self.profiling_data["temp_size"]
156
-
157
- csv_line = (
158
- f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
159
- f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
160
- f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
161
- )
162
-
163
- with open(csv_filename, 'a') as f:
164
- f.write(csv_line)
165
-
166
- param_dict = {
167
- "Function": function,
168
- "Precision": precision,
169
- "X": x,
170
- "Y": y,
171
- "Z": z,
172
- "PX": px,
173
- "PY": py,
174
- "Backend": backend,
175
- "Nodes": nodes,
176
- }
177
- param_dict.update(extra_info)
178
- profiling_result = {
179
- "JIT Time": self.jit_time,
180
- "Min Time": min_time,
181
- "Max Time": max_time,
182
- "Mean Time": mean_time,
183
- "Std Time": std_time,
184
- "Last Time": last_time,
185
- "Generated Code": generated_code,
186
- "Argument Size": argument_size,
187
- "Output Size": output_size,
188
- "Temporary Size": temp_size,
189
- "FLOPS": self.profiling_data["FLOPS"]
190
- }
191
- iteration_runs = {}
192
- for i in range(len(times_array)):
193
- iteration_runs[f"Run {i}"] = times_array[i]
194
-
195
- with open(md_filename, 'w') as f:
196
- f.write(f"# Reporting for {function}\n")
197
- f.write(f"## Parameters\n")
198
- f.write(
199
- tabulate(param_dict.items(),
200
- headers=["Parameter", "Value"],
201
- tablefmt='github'))
202
- f.write("\n---\n")
203
- f.write(f"## Profiling Data\n")
204
- f.write(
205
- tabulate(profiling_result.items(),
206
- headers=["Parameter", "Value"],
207
- tablefmt='github'))
208
- f.write("\n---\n")
209
- f.write(f"## Iteration Runs\n")
210
- f.write(
211
- tabulate(iteration_runs.items(),
212
- headers=["Iteration", "Time"],
213
- tablefmt='github'))
214
- f.write("\n---\n")
215
- f.write(f"## Compiled Code\n")
216
- f.write(f"```hlo\n")
217
- f.write(self.compiled_code["COMPILED"])
218
- f.write(f"\n```\n")
219
- f.write("\n---\n")
220
- f.write(f"## Lowered Code\n")
221
- f.write(f"```hlo\n")
222
- f.write(self.compiled_code["LOWERED"])
223
- f.write(f"\n```\n")
224
- f.write("\n---\n")
225
- if self.save_jaxpr:
226
- f.write(f"## JAXPR\n")
227
- f.write(f"```haskel\n")
228
- f.write(self.compiled_code["JAXPR"])
140
+ if jax.process_index() == 0:
141
+ min_time = np.min(times_array)
142
+ max_time = np.max(times_array)
143
+ mean_time = np.mean(times_array)
144
+ std_time = np.std(times_array)
145
+ last_time = times_array[-1]
146
+
147
+ flops = self.profiling_data["FLOPS"]
148
+ generated_code = self.profiling_data["generated_code"]
149
+ argument_size = self.profiling_data["argument_size"]
150
+ output_size = self.profiling_data["output_size"]
151
+ temp_size = self.profiling_data["temp_size"]
152
+
153
+ csv_line = (
154
+ f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
155
+ f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
156
+ f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
157
+ )
158
+
159
+ with open(csv_filename, 'a') as f:
160
+ f.write(csv_line)
161
+
162
+ param_dict = {
163
+ "Function": function,
164
+ "Precision": precision,
165
+ "X": x,
166
+ "Y": y,
167
+ "Z": z,
168
+ "PX": px,
169
+ "PY": py,
170
+ "Backend": backend,
171
+ "Nodes": nodes,
172
+ }
173
+ param_dict.update(extra_info)
174
+ profiling_result = {
175
+ "JIT Time": self.jit_time,
176
+ "Min Time": min_time,
177
+ "Max Time": max_time,
178
+ "Mean Time": mean_time,
179
+ "Std Time": std_time,
180
+ "Last Time": last_time,
181
+ "Generated Code": self._normalize_memory_units(generated_code),
182
+ "Argument Size": self._normalize_memory_units(argument_size),
183
+ "Output Size": self._normalize_memory_units(output_size),
184
+ "Temporary Size": self._normalize_memory_units(temp_size),
185
+ "FLOPS": self.profiling_data["FLOPS"]
186
+ }
187
+ iteration_runs = {}
188
+ for i in range(len(times_array)):
189
+ iteration_runs[f"Run {i}"] = times_array[i]
190
+
191
+ with open(md_filename, 'w') as f:
192
+ f.write(f"# Reporting for {function}\n")
193
+ f.write(f"## Parameters\n")
194
+ f.write(
195
+ tabulate(param_dict.items(),
196
+ headers=["Parameter", "Value"],
197
+ tablefmt='github'))
198
+ f.write("\n---\n")
199
+ f.write(f"## Profiling Data\n")
200
+ f.write(
201
+ tabulate(profiling_result.items(),
202
+ headers=["Parameter", "Value"],
203
+ tablefmt='github'))
204
+ f.write("\n---\n")
205
+ f.write(f"## Iteration Runs\n")
206
+ f.write(
207
+ tabulate(iteration_runs.items(),
208
+ headers=["Iteration", "Time"],
209
+ tablefmt='github'))
210
+ f.write("\n---\n")
211
+ f.write(f"## Compiled Code\n")
212
+ f.write(f"```hlo\n")
213
+ f.write(self.compiled_code["COMPILED"])
229
214
  f.write(f"\n```\n")
230
-
215
+ f.write("\n---\n")
216
+ f.write(f"## Lowered Code\n")
217
+ f.write(f"```hlo\n")
218
+ f.write(self.compiled_code["LOWERED"])
219
+ f.write(f"\n```\n")
220
+ f.write("\n---\n")
221
+ if self.save_jaxpr:
222
+ f.write(f"## JAXPR\n")
223
+ f.write(f"```haskel\n")
224
+ f.write(self.compiled_code["JAXPR"])
225
+ f.write(f"\n```\n")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -2,11 +2,11 @@ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,
2
2
  jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
3
3
  jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
4
  jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=baE5DRsQBYRBiphkceTi4qI_8FPGKQEh73f2pAeS-oc,8208
5
+ jax_hpc_profiler/timer.py,sha256=r4Mw2tC82cxvkMPkIy8BuZjKikgxn6cviEgmu6rpC9o,8616
6
6
  jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
- jax_hpc_profiler-0.2.3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.3.dist-info/METADATA,sha256=myC-zD7y_pRb_-tZoSFi0KmglZH8Gk88_-U5RE14Q04,49250
9
- jax_hpc_profiler-0.2.3.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.3.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.3.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.3.dist-info/RECORD,,
7
+ jax_hpc_profiler-0.2.5.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.5.dist-info/METADATA,sha256=6Vk6fA1nz-m8ZZVzSZPsg9xR9iJkFJ01x36pddG0RAM,49250
9
+ jax_hpc_profiler-0.2.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
+ jax_hpc_profiler-0.2.5.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.5.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.5.dist-info/RECORD,,