jax-hpc-profiler 0.2.5__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17) hide show
  1. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/PKG-INFO +1 -1
  2. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/pyproject.toml +1 -1
  3. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler/create_argparse.py +20 -18
  4. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler/timer.py +1 -0
  5. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler.egg-info/PKG-INFO +1 -1
  6. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/LICENSE +0 -0
  7. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/README.md +0 -0
  8. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/setup.cfg +0 -0
  9. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler/__init__.py +0 -0
  10. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler/main.py +0 -0
  11. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler/plotting.py +0 -0
  12. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler/utils.py +0 -0
  13. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler.egg-info/SOURCES.txt +0 -0
  14. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler.egg-info/dependency_links.txt +0 -0
  15. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler.egg-info/entry_points.txt +0 -0
  16. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler.egg-info/requires.txt +0 -0
  17. {jax_hpc_profiler-0.2.5 → jax_hpc_profiler-0.2.6}/src/jax_hpc_profiler.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "jax_hpc_profiler"
7
- version = "0.2.5"
7
+ version = "0.2.6"
8
8
  description = "HPC Plotter and profiler for benchmarking data made for JAX"
9
9
  authors = [
10
10
  { name="Wassim Kabalan" }
@@ -136,23 +136,25 @@ def create_argparser():
136
136
  default="%m%-%f%-%pn%-%pr%-%b%-%p%-%n%")
137
137
 
138
138
  args = parser.parse_args()
139
-
140
- if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
141
- print(
142
- "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
143
- )
144
- args.pdim_strategy = ['plot_all']
145
-
146
- if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
147
- print(
148
- "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
149
- )
150
- args.pdim_strategy = ['plot_fastest']
151
- if args.plot_times is not None:
152
- args.plot_columns = args.plot_times
153
- elif args.plot_memory is not None:
154
- args.plot_columns = args.plot_memory
155
- else:
156
- raise ValueError('Either plot_times or plot_memory should be provided')
139
+
140
+ # if the command is 'plot', check that pdim_strategy is valid
141
+ if args.command == 'plot':
142
+ if 'plot_all' in args.pdim_strategy and len(args.pdim_strategy) > 1:
143
+ print(
144
+ "Warning: 'plot_all' strategy is combined with other strategies. Using 'plot_all' only."
145
+ )
146
+ args.pdim_strategy = ['plot_all']
147
+
148
+ if 'plot_fastest' in args.pdim_strategy and len(args.pdim_strategy) > 1:
149
+ print(
150
+ "Warning: 'plot_fastest' strategy is combined with other strategies. Using 'plot_fastest' only."
151
+ )
152
+ args.pdim_strategy = ['plot_fastest']
153
+ if args.plot_times is not None:
154
+ args.plot_columns = args.plot_times
155
+ elif args.plot_memory is not None:
156
+ args.plot_columns = args.plot_memory
157
+ else:
158
+ raise ValueError('Either plot_times or plot_memory should be provided')
157
159
 
158
160
  return args
@@ -138,6 +138,7 @@ class Timer:
138
138
 
139
139
  times_array = self._get_mean_times()
140
140
  if jax.process_index() == 0:
141
+
141
142
  min_time = np.min(times_array)
142
143
  max_time = np.max(times_array)
143
144
  mean_time = np.mean(times_array)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE