jax-hpc-profiler 0.2.3__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -136,23 +136,25 @@ def create_argparser():
136
136
  default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
137
137
 
138
138
  args = parser.parse_args()
139
-
140
- if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
141
- print(
142
- "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
143
- )
144
- args.pdim_strategy = ['plot_all']
145
-
146
- if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
147
- print(
148
- "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
149
- )
150
- args.pdim_strategy = ['plot_fastest']
151
- if args.plot_times is not None:
152
- args.plot_columns = args.plot_times
153
- elif args.plot_memory is not None:
154
- args.plot_columns = args.plot_memory
155
- else:
156
- raise ValueError('Either plot_times or plot_memory should be provided')
139
+
140
+ # If the command is 'plot', check that pdim_strategy is valid
141
+ if args.command == 'plot':
142
+ if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
143
+ print(
144
+ "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
145
+ )
146
+ args.pdim_strategy = ['plot_all']
147
+
148
+ if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
149
+ print(
150
+ "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
151
+ )
152
+ args.pdim_strategy = ['plot_fastest']
153
+ if args.plot_times is not None:
154
+ args.plot_columns = args.plot_times
155
+ elif args.plot_memory is not None:
156
+ args.plot_columns = args.plot_memory
157
+ else:
158
+ raise ValueError('Either plot_times or plot_memory should be provided')
157
159
 
158
160
  return args
jax_hpc_profiler/timer.py CHANGED
@@ -28,15 +28,14 @@ class Timer:
28
28
  return None
29
29
  return cost_analysis[0]['flops']
30
30
 
31
- def _normalize_memory_units(self, memory_analysis) -> str:
32
-
33
- sizes_str = ['B', 'KB', 'MB', 'GB', 'TB' , 'PB']
34
- factors = [1 , 1024 , 1024**2 , 1024**3 , 1024**4 , 1024**5]
31
+ def _normalize_memory_units(self, memory_analysis) -> str:
32
+
33
+ sizes_str = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
34
+ factors = [1, 1024, 1024**2, 1024**3, 1024**4, 1024**5]
35
35
  factor = int(np.log10(memory_analysis) // 3)
36
-
36
+
37
37
  return f"{memory_analysis / factors[factor]:.2f} {sizes_str[factor]}"
38
38
 
39
-
40
39
  def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
41
40
  if memory_analysis is None:
42
41
  return None, None, None, None
@@ -68,12 +67,9 @@ class Timer:
68
67
  self.compiled_code["LOWERED"] = lowered.as_text()
69
68
  self.compiled_code["COMPILED"] = compiled.as_text()
70
69
  self.profiling_data["FLOPS"] = cost_analysis
71
- self.profiling_data[
72
- "generated_code"] = memory_analysis[0]
73
- self.profiling_data[
74
- "argument_size"] = memory_analysis[1]
75
- self.profiling_data[
76
- "output_size"] = memory_analysis[2]
70
+ self.profiling_data["generated_code"] = memory_analysis[0]
71
+ self.profiling_data["argument_size"] = memory_analysis[1]
72
+ self.profiling_data["output_size"] = memory_analysis[2]
77
73
  self.profiling_data["temp_size"] = memory_analysis[3]
78
74
  return out
79
75
 
@@ -101,7 +97,7 @@ class Timer:
101
97
  global_times = jax.make_array_from_callback(
102
98
  shape=global_shape,
103
99
  sharding=sharding,
104
- data_callback=lambda _: jnp.expand_dims(times_array,axis=0))
100
+ data_callback=lambda _: jnp.expand_dims(times_array, axis=0))
105
101
 
106
102
  @partial(shard_map,
107
103
  mesh=mesh,
@@ -141,90 +137,90 @@ class Timer:
141
137
  z = x if z is None else z
142
138
 
143
139
  times_array = self._get_mean_times()
144
- min_time = np.min(times_array)
145
- max_time = np.max(times_array)
146
- mean_time = np.mean(times_array)
147
- std_time = np.std(times_array)
148
- last_time = times_array[-1]
149
-
150
-
151
- flops = self.profiling_data["FLOPS"]
152
- generated_code = self.profiling_data["generated_code"]
153
- argument_size = self.profiling_data["argument_size"]
154
- output_size = self.profiling_data["output_size"]
155
- temp_size = self.profiling_data["temp_size"]
156
-
157
- csv_line = (
158
- f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
159
- f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
160
- f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
161
- )
162
-
163
- with open(csv_filename, 'a') as f:
164
- f.write(csv_line)
165
-
166
- param_dict = {
167
- "Function": function,
168
- "Precision": precision,
169
- "X": x,
170
- "Y": y,
171
- "Z": z,
172
- "PX": px,
173
- "PY": py,
174
- "Backend": backend,
175
- "Nodes": nodes,
176
- }
177
- param_dict.update(extra_info)
178
- profiling_result = {
179
- "JIT Time": self.jit_time,
180
- "Min Time": min_time,
181
- "Max Time": max_time,
182
- "Mean Time": mean_time,
183
- "Std Time": std_time,
184
- "Last Time": last_time,
185
- "Generated Code": generated_code,
186
- "Argument Size": argument_size,
187
- "Output Size": output_size,
188
- "Temporary Size": temp_size,
189
- "FLOPS": self.profiling_data["FLOPS"]
190
- }
191
- iteration_runs = {}
192
- for i in range(len(times_array)):
193
- iteration_runs[f"Run {i}"] = times_array[i]
194
-
195
- with open(md_filename, 'w') as f:
196
- f.write(f"# Reporting for {function}\n")
197
- f.write(f"## Parameters\n")
198
- f.write(
199
- tabulate(param_dict.items(),
200
- headers=["Parameter", "Value"],
201
- tablefmt='github'))
202
- f.write("\n---\n")
203
- f.write(f"## Profiling Data\n")
204
- f.write(
205
- tabulate(profiling_result.items(),
206
- headers=["Parameter", "Value"],
207
- tablefmt='github'))
208
- f.write("\n---\n")
209
- f.write(f"## Iteration Runs\n")
210
- f.write(
211
- tabulate(iteration_runs.items(),
212
- headers=["Iteration", "Time"],
213
- tablefmt='github'))
214
- f.write("\n---\n")
215
- f.write(f"## Compiled Code\n")
216
- f.write(f"```hlo\n")
217
- f.write(self.compiled_code["COMPILED"])
218
- f.write(f"\n```\n")
219
- f.write("\n---\n")
220
- f.write(f"## Lowered Code\n")
221
- f.write(f"```hlo\n")
222
- f.write(self.compiled_code["LOWERED"])
223
- f.write(f"\n```\n")
224
- f.write("\n---\n")
225
- if self.save_jaxpr:
226
- f.write(f"## JAXPR\n")
227
- f.write(f"```haskel\n")
228
- f.write(self.compiled_code["JAXPR"])
140
+ if jax.process_index() == 0:
141
+
142
+ min_time = np.min(times_array)
143
+ max_time = np.max(times_array)
144
+ mean_time = np.mean(times_array)
145
+ std_time = np.std(times_array)
146
+ last_time = times_array[-1]
147
+
148
+ flops = self.profiling_data["FLOPS"]
149
+ generated_code = self.profiling_data["generated_code"]
150
+ argument_size = self.profiling_data["argument_size"]
151
+ output_size = self.profiling_data["output_size"]
152
+ temp_size = self.profiling_data["temp_size"]
153
+
154
+ csv_line = (
155
+ f"{function},{precision},{x},{y},{z},{px},{py},{backend},{nodes},"
156
+ f"{self.jit_time:.4f},{min_time:.4f},{max_time:.4f},{mean_time:.4f},{std_time:.4f},{last_time:.4f},"
157
+ f"{generated_code},{argument_size},{output_size},{temp_size},{flops}\n"
158
+ )
159
+
160
+ with open(csv_filename, 'a') as f:
161
+ f.write(csv_line)
162
+
163
+ param_dict = {
164
+ "Function": function,
165
+ "Precision": precision,
166
+ "X": x,
167
+ "Y": y,
168
+ "Z": z,
169
+ "PX": px,
170
+ "PY": py,
171
+ "Backend": backend,
172
+ "Nodes": nodes,
173
+ }
174
+ param_dict.update(extra_info)
175
+ profiling_result = {
176
+ "JIT Time": self.jit_time,
177
+ "Min Time": min_time,
178
+ "Max Time": max_time,
179
+ "Mean Time": mean_time,
180
+ "Std Time": std_time,
181
+ "Last Time": last_time,
182
+ "Generated Code": self._normalize_memory_units(generated_code),
183
+ "Argument Size": self._normalize_memory_units(argument_size),
184
+ "Output Size": self._normalize_memory_units(output_size),
185
+ "Temporary Size": self._normalize_memory_units(temp_size),
186
+ "FLOPS": self.profiling_data["FLOPS"]
187
+ }
188
+ iteration_runs = {}
189
+ for i in range(len(times_array)):
190
+ iteration_runs[f"Run {i}"] = times_array[i]
191
+
192
+ with open(md_filename, 'w') as f:
193
+ f.write(f"# Reporting for {function}\n")
194
+ f.write(f"## Parameters\n")
195
+ f.write(
196
+ tabulate(param_dict.items(),
197
+ headers=["Parameter", "Value"],
198
+ tablefmt='github'))
199
+ f.write("\n---\n")
200
+ f.write(f"## Profiling Data\n")
201
+ f.write(
202
+ tabulate(profiling_result.items(),
203
+ headers=["Parameter", "Value"],
204
+ tablefmt='github'))
205
+ f.write("\n---\n")
206
+ f.write(f"## Iteration Runs\n")
207
+ f.write(
208
+ tabulate(iteration_runs.items(),
209
+ headers=["Iteration", "Time"],
210
+ tablefmt='github'))
211
+ f.write("\n---\n")
212
+ f.write(f"## Compiled Code\n")
213
+ f.write(f"```hlo\n")
214
+ f.write(self.compiled_code["COMPILED"])
229
215
  f.write(f"\n```\n")
230
-
216
+ f.write("\n---\n")
217
+ f.write(f"## Lowered Code\n")
218
+ f.write(f"```hlo\n")
219
+ f.write(self.compiled_code["LOWERED"])
220
+ f.write(f"\n```\n")
221
+ f.write("\n---\n")
222
+ if self.save_jaxpr:
223
+ f.write(f"## JAXPR\n")
224
+ f.write(f"```haskel\n")
225
+ f.write(self.compiled_code["JAXPR"])
226
+ f.write(f"\n```\n")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.3
3
+ Version: 0.2.6
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -0,0 +1,12 @@
1
+ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
+ jax_hpc_profiler/create_argparse.py,sha256=dEicamRYqJ6GGdgcph2bwAbmdxPkS4tS12xZ4c0X_Pk,6484
3
+ jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
+ jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
+ jax_hpc_profiler/timer.py,sha256=j6oH5IZz12VJik2cE7EQ3a9tAW9C8xl7D2QLW8Bkz3s,8617
6
+ jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
+ jax_hpc_profiler-0.2.6.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.6.dist-info/METADATA,sha256=AgLXyb89gdxgeyDv02_P5oqIvoffdTh1mc3zFPUVBAU,49250
9
+ jax_hpc_profiler-0.2.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
+ jax_hpc_profiler-0.2.6.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.6.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.6.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
- jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
3
- jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
- jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=baE5DRsQBYRBiphkceTi4qI_8FPGKQEh73f2pAeS-oc,8208
6
- jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
- jax_hpc_profiler-0.2.3.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.3.dist-info/METADATA,sha256=myC-zD7y_pRb_-tZoSFi0KmglZH8Gk88_-U5RE14Q04,49250
9
- jax_hpc_profiler-0.2.3.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.3.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.3.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.3.dist-info/RECORD,,