jax-hpc-profiler 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jax_hpc_profiler/timer.py CHANGED
@@ -28,6 +28,15 @@ class Timer:
28
28
  return None
29
29
  return cost_analysis[0]['flops']
30
30
 
31
def _normalize_memory_units(self, memory_analysis) -> str:
    """Format a byte count as a human-readable string with a binary unit.

    The value is repeatedly divided by 1024 until it drops below 1024 or
    the largest known unit ('PB') is reached, then rendered with two
    decimals, e.g. ``1536 -> "1.50 KB"``.

    Args:
        memory_analysis: Size in bytes (int or float).

    Returns:
        The scaled value followed by its unit, e.g. ``"512.00 B"``.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB', 'PB')

    scaled = float(memory_analysis)
    unit_index = 0
    # Pick the unit with the same base (1024) that is used for scaling.
    # The previous log10-based selection used decimal thresholds, so 1000
    # bytes was shown as "0.98 KB", zero/negative inputs raised inside
    # log10/int, and values >= 10**18 indexed past the end of the unit list.
    # Dividing by a power of two is exact in floating point, so stepping
    # one unit at a time does not accumulate rounding error.
    while abs(scaled) >= 1024 and unit_index < len(units) - 1:
        scaled /= 1024
        unit_index += 1

    return f"{scaled:.2f} {units[unit_index]}"
38
+
39
+
31
40
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
32
41
  if memory_analysis is None:
33
42
  return None, None, None, None
@@ -38,7 +47,7 @@ class Timer:
38
47
 
39
48
  def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
40
49
  start = time.perf_counter()
41
- out = jax.jit(fun)(*args)
50
+ out = fun(*args)
42
51
  if ndarray_arg is None:
43
52
  out.block_until_ready()
44
53
  else:
@@ -59,10 +68,13 @@ class Timer:
59
68
  self.compiled_code["LOWERED"] = lowered.as_text()
60
69
  self.compiled_code["COMPILED"] = compiled.as_text()
61
70
  self.profiling_data["FLOPS"] = cost_analysis
62
- self.profiling_data["generated_code"] = memory_analysis[0]
63
- self.profiling_data["argument_size"] = memory_analysis[0]
64
- self.profiling_data["output_size"] = memory_analysis[0]
65
- self.profiling_data["temp_size"] = memory_analysis[0]
71
+ self.profiling_data[
72
+ "generated_code"] = memory_analysis[0]
73
+ self.profiling_data[
74
+ "argument_size"] = memory_analysis[1]
75
+ self.profiling_data[
76
+ "output_size"] = memory_analysis[2]
77
+ self.profiling_data["temp_size"] = memory_analysis[3]
66
78
  return out
67
79
 
68
80
  def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
@@ -89,7 +101,7 @@ class Timer:
89
101
  global_times = jax.make_array_from_callback(
90
102
  shape=global_shape,
91
103
  sharding=sharding,
92
- data_callback=lambda x: times_array)
104
+ data_callback=lambda _: jnp.expand_dims(times_array,axis=0))
93
105
 
94
106
  @partial(shard_map,
95
107
  mesh=mesh,
@@ -101,7 +113,7 @@ class Timer:
101
113
 
102
114
  times_array = get_mean_times(global_times)
103
115
  times_array.block_until_ready()
104
- return np.array(times_array.addressable_data(0))
116
+ return np.array(times_array.addressable_data(0)[0])
105
117
 
106
118
  def report(self,
107
119
  csv_filename: str,
@@ -122,8 +134,6 @@ class Timer:
122
134
  csv_filename), os.path.splitext(
123
135
  os.path.basename(csv_filename))[0]
124
136
  report_folder = filename if dirname == "" else f"{dirname}/{filename}"
125
- print(
126
- f"report_folder: {report_folder} csv_filename: {csv_filename}")
127
137
  os.makedirs(report_folder, exist_ok=True)
128
138
  md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
129
139
 
@@ -131,13 +141,13 @@ class Timer:
131
141
  z = x if z is None else z
132
142
 
133
143
  times_array = self._get_mean_times()
134
-
135
144
  min_time = np.min(times_array)
136
145
  max_time = np.max(times_array)
137
146
  mean_time = np.mean(times_array)
138
147
  std_time = np.std(times_array)
139
148
  last_time = times_array[-1]
140
149
 
150
+
141
151
  flops = self.profiling_data["FLOPS"]
142
152
  generated_code = self.profiling_data["generated_code"]
143
153
  argument_size = self.profiling_data["argument_size"]
@@ -178,12 +188,13 @@ class Timer:
178
188
  "Temporary Size": temp_size,
179
189
  "FLOPS": self.profiling_data["FLOPS"]
180
190
  }
191
+ iteration_runs = {}
192
+ for i in range(len(times_array)):
193
+ iteration_runs[f"Run {i}"] = times_array[i]
181
194
 
182
195
  with open(md_filename, 'w') as f:
183
196
  f.write(f"# Reporting for {function}\n")
184
197
  f.write(f"## Parameters\n")
185
- keys = list(param_dict.keys())
186
- values = list(param_dict.values())
187
198
  f.write(
188
199
  tabulate(param_dict.items(),
189
200
  headers=["Parameter", "Value"],
@@ -195,6 +206,12 @@ class Timer:
195
206
  headers=["Parameter", "Value"],
196
207
  tablefmt='github'))
197
208
  f.write("\n---\n")
209
+ f.write(f"## Iteration Runs\n")
210
+ f.write(
211
+ tabulate(iteration_runs.items(),
212
+ headers=["Iteration", "Time"],
213
+ tablefmt='github'))
214
+ f.write("\n---\n")
198
215
  f.write(f"## Compiled Code\n")
199
216
  f.write(f"```hlo\n")
200
217
  f.write(self.compiled_code["COMPILED"])
@@ -210,3 +227,4 @@ class Timer:
210
227
  f.write(f"```haskel\n")
211
228
  f.write(self.compiled_code["JAXPR"])
212
229
  f.write(f"\n```\n")
230
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -2,11 +2,11 @@ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,
2
2
  jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
3
3
  jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
4
  jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=4XGKuP2fclGfac2sNz_W8aOamFw7TfiT2Nvp6BarMJk,7621
5
+ jax_hpc_profiler/timer.py,sha256=baE5DRsQBYRBiphkceTi4qI_8FPGKQEh73f2pAeS-oc,8208
6
6
  jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
- jax_hpc_profiler-0.2.1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.1.dist-info/METADATA,sha256=smuVIDzcbI2aH4pip8Rnh0qsTNjsnVkP8kvCWA1WTWw,49250
9
- jax_hpc_profiler-0.2.1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.1.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.1.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.1.dist-info/RECORD,,
7
+ jax_hpc_profiler-0.2.3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.3.dist-info/METADATA,sha256=myC-zD7y_pRb_-tZoSFi0KmglZH8Gk88_-U5RE14Q04,49250
9
+ jax_hpc_profiler-0.2.3.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
+ jax_hpc_profiler-0.2.3.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.3.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.3.dist-info/RECORD,,