jax-hpc-profiler 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jax_hpc_profiler/timer.py CHANGED
@@ -28,6 +28,15 @@ class Timer:
28
28
  return None
29
29
  return cost_analysis[0]['flops']
30
30
 
31
+ def _normalize_memory_units(self, memory_analysis) -> str:
32
+
33
+ sizes_str = ['B', 'KB', 'MB', 'GB', 'TB' , 'PB']
34
+ factors = [1 , 1024 , 1024**2 , 1024**3 , 1024**4 , 1024**5]
35
+ factor = int(np.log10(memory_analysis) // 3)
36
+
37
+ return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
+
39
+
31
40
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
32
41
  if memory_analysis is None:
33
42
  return None, None, None, None
@@ -38,7 +47,7 @@ class Timer:
38
47
 
39
48
  def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
40
49
  start = time.perf_counter()
41
- out = jax.jit(fun)(*args)
50
+ out = fun(*args)
42
51
  if ndarray_arg is None:
43
52
  out.block_until_ready()
44
53
  else:
@@ -62,10 +71,10 @@ class Timer:
62
71
  self.profiling_data[
63
72
  "generated_code"] = memory_analysis[0]
64
73
  self.profiling_data[
65
- "argument_size"] = memory_analysis[0]
74
+ "argument_size"] = memory_analysis[1]
66
75
  self.profiling_data[
67
- "output_size"] = memory_analysis[0]
68
- self.profiling_data["temp_size"] = memory_analysis[0]
76
+ "output_size"] = memory_analysis[2]
77
+ self.profiling_data["temp_size"] = memory_analysis[3]
69
78
  return out
70
79
 
71
80
  def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
@@ -92,7 +101,7 @@ class Timer:
92
101
  global_times = jax.make_array_from_callback(
93
102
  shape=global_shape,
94
103
  sharding=sharding,
95
- data_callback=lambda x: times_array)
104
+ data_callback=lambda _: jnp.expand_dims(times_array,axis=0))
96
105
 
97
106
  @partial(shard_map,
98
107
  mesh=mesh,
@@ -104,7 +113,7 @@ class Timer:
104
113
 
105
114
  times_array = get_mean_times(global_times)
106
115
  times_array.block_until_ready()
107
- return np.array(times_array.addressable_data(0))
116
+ return np.array(times_array.addressable_data(0)[0])
108
117
 
109
118
  def report(self,
110
119
  csv_filename: str,
@@ -132,13 +141,13 @@ class Timer:
132
141
  z = x if z is None else z
133
142
 
134
143
  times_array = self._get_mean_times()
135
-
136
144
  min_time = np.min(times_array)
137
145
  max_time = np.max(times_array)
138
146
  mean_time = np.mean(times_array)
139
147
  std_time = np.std(times_array)
140
148
  last_time = times_array[-1]
141
149
 
150
+
142
151
  flops = self.profiling_data["FLOPS"]
143
152
  generated_code = self.profiling_data["generated_code"]
144
153
  argument_size = self.profiling_data["argument_size"]
@@ -218,3 +227,4 @@ class Timer:
218
227
  f.write(f"```haskel\n")
219
228
  f.write(self.compiled_code["JAXPR"])
220
229
  f.write(f"\n```\n")
230
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -2,11 +2,11 @@ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,
2
2
  jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
3
3
  jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
4
  jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=BhwWrdPIdZHb2KCpyWYXXNQyBxl4QU3XEmfft5FA1Vc,7843
5
+ jax_hpc_profiler/timer.py,sha256=baE5DRsQBYRBiphkceTi4qI_8FPGKQEh73f2pAeS-oc,8208
6
6
  jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
- jax_hpc_profiler-0.2.2.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.2.dist-info/METADATA,sha256=jrZBe0YAZIRuAcgSJryVccu5Rgk7z5G6UwplaEBPNb0,49250
9
- jax_hpc_profiler-0.2.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.2.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.2.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.2.dist-info/RECORD,,
7
+ jax_hpc_profiler-0.2.3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.3.dist-info/METADATA,sha256=myC-zD7y_pRb_-tZoSFi0KmglZH8Gk88_-U5RE14Q04,49250
9
+ jax_hpc_profiler-0.2.3.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
+ jax_hpc_profiler-0.2.3.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.3.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.3.dist-info/RECORD,,