jax-hpc-profiler 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -89,9 +89,9 @@ def create_argparser():
89
89
  ],
90
90
  help='Memory columns to plot')
91
91
  plot_parser.add_argument('-mu',
92
- '--memory_units',
93
- default='GB',
94
- help='Memory units to plot (KB, MB, GB, TB)')
92
+ '--memory_units',
93
+ default='GB',
94
+ help='Memory units to plot (KB, MB, GB, TB)')
95
95
 
96
96
  # Plot customization arguments
97
97
  plot_parser.add_argument('-fs',
jax_hpc_profiler/main.py CHANGED
@@ -15,7 +15,7 @@ def main():
15
15
  dataframes, available_gpu_counts, available_data_sizes = clean_up_csv(
16
16
  args.csv_files, args.precision, args.function_name, args.gpus,
17
17
  args.data_size, args.filter_pdims, args.pdim_strategy,
18
- args.backends,args.memory_units)
18
+ args.backends, args.memory_units)
19
19
  if len(dataframes) == 0:
20
20
  print(f"No dataframes found for the given arguments. Exiting...")
21
21
  sys.exit(1)
@@ -29,12 +29,10 @@ def main():
29
29
  if data_size in available_data_sizes
30
30
  ]
31
31
  if len(args.gpus) == 0:
32
- print(
33
- f"No dataframes found for the given GPUs. Exiting...")
32
+ print(f"No dataframes found for the given GPUs. Exiting...")
34
33
  sys.exit(1)
35
34
  if len(args.data_size) == 0:
36
- print(
37
- f"No dataframes found for the given data sizes. Exiting...")
35
+ print(f"No dataframes found for the given data sizes. Exiting...")
38
36
  sys.exit(1)
39
37
 
40
38
  if args.scaling == 'Weak':
jax_hpc_profiler/timer.py CHANGED
@@ -1,12 +1,13 @@
1
1
  import os
2
2
  import time
3
3
  from functools import partial
4
- from typing import Any, Callable, List
4
+ from typing import Any, Callable, List, Tuple
5
5
 
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
9
  from jax import make_jaxpr
10
+ from jax.experimental import mesh_utils
10
11
  from jax.experimental.shard_map import shard_map
11
12
  from jax.sharding import Mesh, NamedSharding
12
13
  from jax.sharding import PartitionSpec as P
@@ -22,6 +23,19 @@ class Timer:
22
23
  self.compiled_code = {}
23
24
  self.save_jaxpr = save_jaxpr
24
25
 
26
+ def _read_cost_analysis(self, cost_analysis: Any) -> str | None:
27
+ if cost_analysis is None:
28
+ return None
29
+ return cost_analysis[0]['flops']
30
+
31
+ def _read_memory_analysis(self, memory_analysis: Any) -> Tuple:
32
+ if memory_analysis is None:
33
+ return None, None, None, None
34
+ return (memory_analysis.generated_code_size_in_bytes,
35
+ memory_analysis.argument_size_in_bytes,
36
+ memory_analysis.output_size_in_bytes,
37
+ memory_analysis.temp_size_in_bytes)
38
+
25
39
  def chrono_jit(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
26
40
  start = time.perf_counter()
27
41
  out = jax.jit(fun)(*args)
@@ -38,18 +52,17 @@ class Timer:
38
52
 
39
53
  lowered = jax.jit(fun).lower(*args)
40
54
  compiled = lowered.compile()
41
- memory_analysis = compiled.memory_analysis()
55
+ memory_analysis = self._read_memory_analysis(
56
+ compiled.memory_analysis())
57
+ cost_analysis = self._read_cost_analysis(compiled.cost_analysis())
58
+
42
59
  self.compiled_code["LOWERED"] = lowered.as_text()
43
60
  self.compiled_code["COMPILED"] = compiled.as_text()
44
- self.profiling_data["FLOPS"] = compiled.cost_analysis()[0]['flops']
45
- self.profiling_data[
46
- "generated_code"] = memory_analysis.generated_code_size_in_bytes
47
- self.profiling_data[
48
- "argument_size"] = memory_analysis.argument_size_in_bytes
49
- self.profiling_data[
50
- "output_size"] = memory_analysis.output_size_in_bytes
51
- self.profiling_data["temp_size"] = memory_analysis.temp_size_in_bytes
52
-
61
+ self.profiling_data["FLOPS"] = cost_analysis
62
+ self.profiling_data["generated_code"] = memory_analysis[0]
63
+ self.profiling_data["argument_size"] = memory_analysis[0]
64
+ self.profiling_data["output_size"] = memory_analysis[0]
65
+ self.profiling_data["temp_size"] = memory_analysis[0]
53
66
  return out
54
67
 
55
68
  def chrono_fun(self, fun: Callable, *args, ndarray_arg=None) -> np.ndarray:
@@ -63,56 +76,62 @@ class Timer:
63
76
  self.times.append((end - start) * 1e3)
64
77
  return out
65
78
 
66
- def _get_mean_times(self, times_array: jnp.ndarray,
67
- sharding: NamedSharding):
68
- mesh = sharding.mesh
69
- specs = sharding.spec
70
- valid_letters = [letter for letter in specs if letter is not None]
71
- assert len(valid_letters
72
- ) > 0, "Sharding was provided but with no partition specs"
79
+ def _get_mean_times(self) -> np.ndarray:
80
+ if jax.device_count() == 1:
81
+ return np.array(self.times)
82
+
83
+ devices = mesh_utils.create_device_mesh((jax.device_count(), ))
84
+ mesh = Mesh(devices, ('x', ))
85
+ sharding = NamedSharding(mesh, P('x'))
86
+
87
+ times_array = jnp.array(self.times)
88
+ global_shape = (jax.device_count(), times_array.shape[0])
89
+ global_times = jax.make_array_from_callback(
90
+ shape=global_shape,
91
+ sharding=sharding,
92
+ data_callback=lambda x: times_array)
73
93
 
74
94
  @partial(shard_map,
75
95
  mesh=mesh,
76
- in_specs=specs,
96
+ in_specs=P('x'),
77
97
  out_specs=P(),
78
98
  check_rep=False)
79
99
  def get_mean_times(times):
80
- mean = jax.lax.pmean(times, axis_name=valid_letters[0])
81
- for axis_name in valid_letters[1:]:
82
- mean = jax.lax.pmean(mean, axis_name=axis_name)
83
- return mean
100
+ return jax.lax.pmean(times, axis_name='x')
84
101
 
85
- times_array = get_mean_times(times_array)
102
+ times_array = get_mean_times(global_times)
86
103
  times_array.block_until_ready()
87
- return times_array
104
+ return np.array(times_array.addressable_data(0))
88
105
 
89
106
  def report(self,
90
107
  csv_filename: str,
91
108
  function: str,
92
- precision: str,
93
109
  x: int,
94
- y: int,
95
- z: int,
96
- px: int,
97
- py: int,
98
- backend: str,
99
- nodes: int,
100
- sharding: NamedSharding | None = None,
110
+ y: int | None = None,
111
+ z: int | None = None,
112
+ precision: str = "float32",
113
+ px: int = 1,
114
+ py: int = 1,
115
+ backend: str = "NCCL",
116
+ nodes: int = 1,
101
117
  md_filename: str | None = None,
102
118
  extra_info: dict = {}):
103
- times_array = jnp.array(self.times)
104
119
 
105
120
  if md_filename is None:
106
- dirname, filename = os.path.dirname(csv_filename), os.path.splitext(os.path.basename(csv_filename))[0]
121
+ dirname, filename = os.path.dirname(
122
+ csv_filename), os.path.splitext(
123
+ os.path.basename(csv_filename))[0]
107
124
  report_folder = filename if dirname == "" else f"{dirname}/{filename}"
108
- print(f"report_folder: {report_folder} csv_filename: {csv_filename}")
125
+ print(
126
+ f"report_folder: {report_folder} csv_filename: {csv_filename}")
109
127
  os.makedirs(report_folder, exist_ok=True)
110
128
  md_filename = f"{report_folder}/{x}_{px}_{py}_{backend}_{precision}_{function}.md"
111
129
 
112
- if sharding is not None:
113
- times_array = self._get_mean_times(times_array, sharding)
130
+ y = x if y is None else y
131
+ z = x if z is None else z
132
+
133
+ times_array = self._get_mean_times()
114
134
 
115
- times_array = np.array(times_array)
116
135
  min_time = np.min(times_array)
117
136
  max_time = np.max(times_array)
118
137
  mean_time = np.mean(times_array)
@@ -163,10 +182,18 @@ class Timer:
163
182
  with open(md_filename, 'w') as f:
164
183
  f.write(f"# Reporting for {function}\n")
165
184
  f.write(f"## Parameters\n")
166
- f.write(tabulate(param_dict.items() , headers=["Parameter" , "Value"] , tablefmt='github'))
185
+ keys = list(param_dict.keys())
186
+ values = list(param_dict.values())
187
+ f.write(
188
+ tabulate(param_dict.items(),
189
+ headers=["Parameter", "Value"],
190
+ tablefmt='github'))
167
191
  f.write("\n---\n")
168
192
  f.write(f"## Profiling Data\n")
169
- f.write(tabulate(profiling_result.items() , headers=["Parameter" , "Value"] , tablefmt='github'))
193
+ f.write(
194
+ tabulate(profiling_result.items(),
195
+ headers=["Parameter", "Value"],
196
+ tablefmt='github'))
170
197
  f.write("\n---\n")
171
198
  f.write(f"## Compiled Code\n")
172
199
  f.write(f"```hlo\n")
jax_hpc_profiler/utils.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
6
  import pandas as pd
7
7
  from matplotlib.axes import Axes
8
8
 
9
+
9
10
  def inspect_data(dataframes: Dict[str, pd.DataFrame]):
10
11
  """
11
12
  Inspect the dataframes.
@@ -203,17 +204,17 @@ def concatenate_csvs(root_dir: str, output_dir: str):
203
204
  if file.endswith('.csv'):
204
205
  csv_file_path = os.path.join(root, file)
205
206
  print(f'Concatenating {csv_file_path}...')
206
- df = pd.read_csv(
207
- csv_file_path,
208
- header=None,
209
- names=[
210
- "function", "precision", "x", "y", "z", "px",
211
- "py", "backend", "nodes", "jit_time", "min_time",
212
- "max_time", "mean_time", "std_div", "last_time",
213
- "generated_code", "argument_size", "output_size",
214
- "temp_size", "flops"
215
- ],
216
- index_col=False)
207
+ df = pd.read_csv(csv_file_path,
208
+ header=None,
209
+ names=[
210
+ "function", "precision", "x", "y",
211
+ "z", "px", "py", "backend", "nodes",
212
+ "jit_time", "min_time", "max_time",
213
+ "mean_time", "std_div", "last_time",
214
+ "generated_code", "argument_size",
215
+ "output_size", "temp_size", "flops"
216
+ ],
217
+ index_col=False)
217
218
  if file not in combined_dfs:
218
219
  combined_dfs[file] = df
219
220
  else:
@@ -340,23 +341,23 @@ def clean_up_csv(
340
341
  if pdims:
341
342
  px_list, py_list = zip(*[map(int, p.split('x')) for p in pdims])
342
343
  df = df[(df['px'].isin(px_list)) & (df['py'].isin(py_list))]
343
-
344
- # convert memory units columns to remquested memory_units
344
+
345
+ # convert memory units columns to requested memory_units
345
346
  match memory_units:
346
- case 'KB':
347
- factor = 1024
348
- case 'MB':
349
- factor = 1024**2
350
- case 'GB':
351
- factor = 1024**3
352
- case 'TB':
353
- factor = 1024**4
354
- case _:
355
- factor = 1
356
-
347
+ case 'KB':
348
+ factor = 1024
349
+ case 'MB':
350
+ factor = 1024**2
351
+ case 'GB':
352
+ factor = 1024**3
353
+ case 'TB':
354
+ factor = 1024**4
355
+ case _:
356
+ factor = 1
357
+
357
358
  df['generated_code'] = df['generated_code'] / factor
358
- df['argument_size'] = df['argument_size'] / factor
359
- df['output_size'] = df['output_size'] / factor
359
+ df['argument_size'] = df['argument_size'] / factor
360
+ df['output_size'] = df['output_size'] / factor
360
361
  df['temp_size'] = df['temp_size'] / factor
361
362
  # in case of the same test is run multiple times, keep the last one
362
363
  df = df.drop_duplicates(subset=[
@@ -383,7 +384,7 @@ def clean_up_csv(
383
384
  df['decomp'] = df.apply(get_decomp_from_px_py, axis=1)
384
385
  df.drop(columns=['px', 'py'], inplace=True)
385
386
  if not 'plot_all' in pdims_strategy:
386
- df = df[df['decomp'].isin(pdims_strategy)]
387
+ df = df[df['decomp'].isin(pdims_strategy)]
387
388
  # check available gpus in dataset
388
389
  available_gpu_counts.update(df['gpus'].unique())
389
390
  available_data_sizes.update(df['x'].unique())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jax_hpc_profiler
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: HPC Plotter and profiler for benchmarking data made for JAX
5
5
  Author: Wassim Kabalan
6
6
  License: GNU GENERAL PUBLIC LICENSE
@@ -678,6 +678,19 @@ License: GNU GENERAL PUBLIC LICENSE
678
678
  Public License instead of this License. But first, please read
679
679
  <https://www.gnu.org/licenses/why-not-lgpl.html>.
680
680
 
681
+ Project-URL: Homepage, https://github.com/ASKabalan/jax-hpc-profiler
682
+ Keywords: jax,hpc,profiler,plotter,benchmarking
683
+ Classifier: Development Status :: 4 - Beta
684
+ Classifier: Intended Audience :: Developers
685
+ Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
686
+ Classifier: Programming Language :: Python :: 3
687
+ Classifier: Programming Language :: Python :: 3.8
688
+ Classifier: Programming Language :: Python :: 3.9
689
+ Classifier: Programming Language :: Python :: 3.10
690
+ Classifier: Programming Language :: Python :: 3.11
691
+ Classifier: Programming Language :: Python :: 3.12
692
+ Classifier: Programming Language :: Python :: 3 :: Only
693
+ Requires-Python: >=3.8
681
694
  Description-Content-Type: text/markdown
682
695
  License-File: LICENSE
683
696
  Requires-Dist: numpy
@@ -686,9 +699,11 @@ Requires-Dist: matplotlib
686
699
  Requires-Dist: seaborn
687
700
  Requires-Dist: tabulate
688
701
 
689
- # HPC Plotter
702
+
690
703
 
691
- HPC Plotter is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
704
+ # JAX HPC Profiler
705
+
706
+ JAX HPC Profiler is a tool designed for benchmarking and visualizing performance data in high-performance computing (HPC) environments. It provides functionalities to generate, concatenate, and plot CSV data from various runs.
692
707
 
693
708
  ## Table of Contents
694
709
  - [Introduction](#introduction)
@@ -697,9 +712,10 @@ HPC Plotter is a tool designed for benchmarking and visualizing performance data
697
712
  - [CSV Structure](#csv-structure)
698
713
  - [Concatenating Files from Different Runs](#concatenating-files-from-different-runs)
699
714
  - [Plotting CSV Data](#plotting-csv-data)
715
+ - [Examples](#examples)
700
716
 
701
717
  ## Introduction
702
- HPC Plotter allows users to:
718
+ JAX HPC Profiler allows users to:
703
719
  1. Generate CSV files containing performance data.
704
720
  2. Concatenate multiple CSV files from different runs.
705
721
  3. Plot the performance data for analysis.
@@ -709,53 +725,80 @@ HPC Plotter allows users to:
709
725
  To install the package, run the following command:
710
726
 
711
727
  ```bash
712
- pip install hpc-plotter
728
+ pip install jax-hpc-profiler
713
729
  ```
730
+
714
731
  ## Generating CSV Files Using the Timer Class
715
732
 
716
- To generate CSV files, you can use the `Timer` class provided in the `hpc_plotter.timer` module. This class helps in timing functions and saving the timing results to CSV files.
733
+ To generate CSV files, you can use the `Timer` class provided in the `jax_hpc_profiler.timer` module. This class helps in timing functions and saving the timing results to CSV files.
717
734
 
718
735
  ### Example Usage
719
736
 
720
737
  ```python
721
- import time
722
- from hpc_plotter.timer import Timer
723
738
  import jax
724
- # Define the functions you want to time
725
- def example_function():
726
- time.sleep(1) # Simulating a task
727
-
728
- # Create a Timer instance
729
- timer = Timer()
730
-
731
- # Time the function
732
- timer.chrono_jit(example_function)
733
- for _ in range(5):
734
- timer.chrono_fun(example_function)
735
-
736
- # Metadata for the CSV file
737
- metadata = {
738
- 'rank': jax.process_index(),
739
- 'function_name': 'example_function',
740
- 'precision': 'float32',
741
- 'x': '1024',
742
- 'y': '1024',
743
- 'z': '1024',
744
- 'px': '4',
745
- 'py': '4',
746
- 'backend': 'NCCL',
747
- 'nodes': '2'
739
+ from jax_hpc_profiler import Timer
740
+
741
+ def fcn(m, n, k):
742
+ return jax.numpy.dot(m, n) + k
743
+
744
+ timer = Timer(save_jaxpr=True)
745
+ m = jax.numpy.ones((1000, 1000))
746
+ n = jax.numpy.ones((1000, 1000))
747
+ k = jax.numpy.ones((1000, 1000))
748
+
749
+ timer.chrono_jit(fcn, m, n, k)
750
+ for i in range(10):
751
+ timer.chrono_fun(fcn, m, n, k)
752
+
753
+ meta_data = {
754
+ "function": "fcn",
755
+ "precision": "float32",
756
+ "x": 1000,
757
+ "y": 1000,
758
+ "z": 1000,
759
+ "px": 1,
760
+ "py": 1,
761
+ "backend": "NCCL",
762
+ "nodes": 1
763
+ }
764
+ extra_info = {
765
+ "done": "yes"
748
766
  }
749
767
 
750
- # Print the results to a CSV file
751
- timer.print_to_csv('output.csv', **metadata)
768
+ timer.report("examples/profiling/test.csv", **meta_data, extra_info=extra_info)
752
769
  ```
753
770
 
771
+ `timer.report` has sensible defaults. Its parameters are:
772
+
773
+ - `csv_filename`: The path to the CSV file to save the timing data **(required)**.
774
+ - `function`: The name of the function being timed **(required)**.
775
+ - `x`: The size of the input data in the x dimension **(required)**.
776
+ - `y`: The size of the input data in the y dimension (by default same as x).
777
+ - `z`: The size of the input data in the z dimension (by default same as x).
778
+ - `precision`: The precision of the data (default: "float32").
779
+ - `px`: The number of partitions in the x dimension (default: 1).
780
+ - `py`: The number of partitions in the y dimension (default: 1).
781
+ - `backend`: The backend used for computation (default: "NCCL").
782
+ - `nodes`: The number of nodes used for computation (default: 1).
783
+ - `md_filename`: The path to the markdown file containing the compiled code and other information (default: `{csv_dir}/{csv_name}/{x}_{px}_{py}_{backend}_{precision}_{function}.md`, i.e. inside a folder named after the CSV file without its extension).
784
+ - `extra_info`: Additional information to include in the report (default: `{}`).
785
+
786
+ `px` and `py` are used to specify the data decomposition. For example, if you have a 2D array of size 1000x1000 and you partition it into 4 parts (2x2), you would set `px=2` and `py=2`.\
787
+ They can also be used in a single-device run to specify the batch size.
788
+
789
+ The following decomposition labels are generated automatically and are specific to 3D data decomposition:\
790
+ `slab_yz` if the distributed axis is the y-axis.\
791
+ `slab_xy` if the distributed axis is the x-axis.\
792
+ `pencils` if the distributed axis are the x and y axes.
793
+
794
+ ### Multi-GPU Setup
795
+
796
+ In a multi-GPU setup, the times are automatically averaged across ranks, providing a single performance metric for the entire setup.
797
+
754
798
  ## CSV Structure
755
799
 
756
800
  The CSV files should follow a specific structure to ensure proper processing and concatenation. The directory structure should be organized by GPU type, with subdirectories for the number of GPUs and the respective CSV files.
757
801
 
758
-
759
802
  ### Example Directory Structure
760
803
 
761
804
  ```
@@ -790,12 +833,12 @@ root_directory/
790
833
 
791
834
  ## Concatenating Files from Different Runs
792
835
 
793
- The `plot` function expects the directory to be organized as described above, but with the different number of GPUs toghether in the same directory. The `concatenate` function can be used to concatenate the CSV files from different runs into a single file.
836
+ The `plot` function expects the directory to be organized as described above, but with the different number of GPUs together in the same directory. The `concatenate` function can be used to concatenate the CSV files from different runs into a single file.
794
837
 
795
838
  ### Example Usage
796
839
 
797
840
  ```bash
798
- hpc-plotter concat /path/to/root_directory /path/to/output
841
+ jax-hpc-profiler concat /path/to/root_directory /path/to/output
799
842
  ```
800
843
 
801
844
  And the output will be:
@@ -812,8 +855,6 @@ out_directory/
812
855
  └── method_3.csv
813
856
  ```
814
857
 
815
-
816
-
817
858
  ## Plotting CSV Data
818
859
 
819
860
  You can plot the performance data using the `plot` command. The plotting command provides various options to customize the plots.
@@ -821,27 +862,41 @@ You can plot the performance data using the `plot` command. The plotting command
821
862
  ### Usage
822
863
 
823
864
  ```bash
824
- hpc-plotter plot -f <csv_files> [options]
865
+ jax-hpc-profiler plot -f <csv_files> [options]
825
866
  ```
826
867
 
827
- with options :
828
-
868
+ ### Options
829
869
 
830
870
  - `-f, --csv_files`: List of CSV files to plot (required).
831
- - `-g, --gpus`: Filter GPUs. List of number of GPUs to plot.
832
- - `-d, --data_size`: Filter data sizes. List of data sizes to plot.
871
+ - `-g, --gpus`: List of number of GPUs to plot.
872
+ - `-d, --data_size`: List of data sizes to plot.
833
873
  - `-fd, --filter_pdims`: List of pdims to filter (e.g., 1x4 2x2 4x8).
834
- - `-ps, --pdims_strategy`: Strategy for plotting pdims (`plot_all` or `plot_fastest`).
835
- - `plot_all`: Plot every decomposition. 1xX and Xx1 as slabs, XxX as pencils.
874
+ - `-ps, --pdim_strategy`: Strategy for plotting pdims. Accepts multiple values (`plot_all`, `plot_fastest`, `slab_yz`, `slab_xy`, `pencils`).
875
+ - `plot_all`: Plot every decomposition.
836
876
  - `plot_fastest`: Plot the fastest decomposition.
837
- - `-p, --precision`: Precision to filter by (`float32` or `float64`).
838
- - `-fn, --function_name`: Function name to filter.
839
- - `-ta, --time_aggregation`: Time aggregation method (`mean`, `min`, `max`).
840
- - `-tc, --time_column`: Time column to plot (`jit_time`, `min_time`, `max_time`, `mean_time`, `std_div`, `last_time`).
877
+ - `-pr, --precision`: Precision to filter by. Accepts multiple values (`float32`, `float64`).
878
+ - `-fn, --function_name`: Function names to filter by. Accepts multiple values.
879
+ - `-pt, --plot_times`: Time columns to plot (`jit_time`, `min_time`, `max_time`, `mean_time`, `std_time`, `last_time`). Note: You cannot plot memory and time together.
880
+ - `-pm, --plot_memory`: Memory columns to plot (`generated_code`, `argument_size`, `output_size`, `temp_size`). Note: You cannot plot memory and time together.
881
+ - `-mu, --memory_units`: Memory units to plot (`KB`, `MB`, `GB`, `TB`).
841
882
  - `-fs, --figure_size`: Figure size.
842
- - `-nl, --nodes_in_label`: Use node names in labels.
843
883
  - `-o, --output`: Output file (if none then only show plot).
844
884
  - `-db, --dark_bg`: Use dark background for plotting.
845
- - `-pd, --print_decompositions`: Print decompositions on plot (only for `plot_fastest`).
846
- - `-b, --backends`: List of backends to include.
847
- - `-sc, --scaling`: Scaling type (`Weak` or `Strong`).
885
+ - `-pd, --print_decompositions`: Print decompositions on plot (experimental).
886
+ - `-b, --backends`: List of backends to include. Accepts multiple values.
887
+ - `-sc, --scaling`: Scaling type (`Weak`, `Strong`).
888
+ - `-l, --label_text`: Custom label for the plot. You can use placeholders: `%decomposition%` (or `%p%`), `%precision%` (or `%pr%`), `%plot_name%` (or `%pn%`), `%backend%` (or `%b%`), `%node%` (or `%n%`), `%methodname%` (or `%m%`).
889
+
890
+ ## Examples
891
+
892
+ The repository includes examples for both profiling and plotting.
893
+
894
+ ### Profiling Example
895
+
896
+ See the `examples/profiling` directory for profiling examples, including `function.py`, `test.csv`, and the generated markdown report.
897
+
898
+ ### Plotting Example
899
+
900
+ See the `examples/plotting` directory for plotting examples, including `generator.py`, `sample_data1.csv`, `sample_data2.csv`, and `sample_data3.csv`.
901
+
902
+ A multi-GPU example comparing distributed FFTs can be found in [jaxdecomp-benchmarks](https://github.com/ASKabalan/jaxdecomp-benchmarks).
@@ -0,0 +1,12 @@
1
+ jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
+ jax_hpc_profiler/create_argparse.py,sha256=sY3OKe6lMrXtVnKyx-EtREXLy9L1TK_mdf0WYRQXu5A,6351
3
+ jax_hpc_profiler/main.py,sha256=CKsKVUKsMRatlYfrFLOV1WZ582rZPtofV89sY_2tpQI,2370
4
+ jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
+ jax_hpc_profiler/timer.py,sha256=4XGKuP2fclGfac2sNz_W8aOamFw7TfiT2Nvp6BarMJk,7621
6
+ jax_hpc_profiler/utils.py,sha256=GqlGaD-Zf9GmtuQRJAF6S6QD4E2S8QtZ-jGjjC5ZFwU,14133
7
+ jax_hpc_profiler-0.2.1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
+ jax_hpc_profiler-0.2.1.dist-info/METADATA,sha256=smuVIDzcbI2aH4pip8Rnh0qsTNjsnVkP8kvCWA1WTWw,49250
9
+ jax_hpc_profiler-0.2.1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
+ jax_hpc_profiler-0.2.1.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
+ jax_hpc_profiler-0.2.1.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
+ jax_hpc_profiler-0.2.1.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
- jax_hpc_profiler/create_argparse.py,sha256=JHCbfU6ChNyTMxLjqf0DOCAScRHEV6K0ZC6MHJ-9ofc,6336
3
- jax_hpc_profiler/main.py,sha256=uOWduNhn8guNMTU5zpkG2QMGirX_jrSvyzK4tvwYI2k,2403
4
- jax_hpc_profiler/plotting.py,sha256=iA-CCBsDK5wM872tmpS6qo4ws9F6MZLutfkdAN_sWlw,8326
5
- jax_hpc_profiler/timer.py,sha256=YO6KV3k1jJMgPjYYIaPi1aHAYoIYy1J7Qu76Vm2eUyk,6770
6
- jax_hpc_profiler/utils.py,sha256=itGSe15pS0Qi07nZe2PUXZ95ZampbMmUbrTsgB8g2zU,13990
7
- jax_hpc_profiler-0.2.0.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.0.dist-info/METADATA,sha256=kA-5n1LLeqXAKBt3qLQLEYqKu7-OghOUK3raso533bk,45786
9
- jax_hpc_profiler-0.2.0.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
10
- jax_hpc_profiler-0.2.0.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.0.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.0.dist-info/RECORD,,