jax-hpc-profiler 0.2.1__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/PKG-INFO +1 -1
  2. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/pyproject.toml +1 -1
  3. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler/timer.py +15 -7
  4. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
  5. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/LICENSE +0 -0
  6. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/README.md +0 -0
  7. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/setup.cfg +0 -0
  8. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler/__init__.py +0 -0
  9. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler/create_argparse.py +0 -0
  10. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler/main.py +0 -0
  11. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler/plotting.py +0 -0
  12. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler/utils.py +0 -0
  13. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  14. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  15. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  16. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  17. {jax_hpc_profiler-0.2.1 → jax_hpc_profiler-0.2.2}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.1"
7
+ version = "0.2.2"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -59,9 +59,12 @@ class Timer:
59
59
  self.compiled_code["LOWERED"] = lowered.as_text()
60
60
  self.compiled_code["COMPILED"] = compiled.as_text()
61
61
  self.profiling_data["FLOPS"] = cost_analysis
62
- self.profiling_data["generated_code"] = memory_analysis[0]
63
- self.profiling_data["argument_size"] = memory_analysis[0]
64
- self.profiling_data["output_size"] = memory_analysis[0]
62
+ self.profiling_data[
63
+ "generated_code"] = memory_analysis[0]
64
+ self.profiling_data[
65
+ "argument_size"] = memory_analysis[0]
66
+ self.profiling_data[
67
+ "output_size"] = memory_analysis[0]
65
68
  self.profiling_data["temp_size"] = memory_analysis[0]
66
69
  return out
67
70
 
@@ -122,8 +125,6 @@ class Timer:
122
125
  csv_filename), os.path.splitext(
123
126
  os.path.basename(csv_filename))[0]
124
127
  report_folder = filename if dirname == "" else f"{dirname}/{filename}"
125
- print(
126
- f"report_folder: {report_folder} csv_filename: {csv_filename}")
127
128
  os.makedirs(report_folder, exist_ok=True)
128
129
  md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
129
130
 
@@ -178,12 +179,13 @@ class Timer:
178
179
  "Temporary Size": temp_size,
179
180
  "FLOPS": self.profiling_data["FLOPS"]
180
181
  }
182
+ iteration_runs = {}
183
+ for i in range(len(times_array)):
184
+ iteration_runs[f"Run {i}"] = times_array[i]
181
185
 
182
186
  with open(md_filename, 'w') as f:
183
187
  f.write(f"# Reporting for {function}\n")
184
188
  f.write(f"## Parameters\n")
185
- keys = list(param_dict.keys())
186
- values = list(param_dict.values())
187
189
  f.write(
188
190
  tabulate(param_dict.items(),
189
191
  headers=["Parameter", "Value"],
@@ -195,6 +197,12 @@ class Timer:
195
197
  headers=["Parameter", "Value"],
196
198
  tablefmt='github'))
197
199
  f.write("\n---\n")
200
+ f.write(f"## Iteration Runs\n")
201
+ f.write(
202
+ tabulate(iteration_runs.items(),
203
+ headers=["Iteration", "Time"],
204
+ tablefmt='github'))
205
+ f.write("\n---\n")
198
206
  f.write(f"## Compiled Code\n")
199
207
  f.write(f"```hlo\n")
200
208
  f.write(self.compiled_code["COMPILED"])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE