jax-hpc-profiler 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
jax_hpc_profiler/utils.py CHANGED
@@ -1,13 +1,11 @@
 import os
 from typing import Dict, List, Optional, Tuple
 
-import matplotlib.pyplot as plt
-import numpy as np
 import pandas as pd
 from matplotlib.axes import Axes
 
 
-def inspect_data(dataframes: Dict[str, pd.DataFrame]):
+def inspect_data(dataframes: Dict[str, pd.DataFrame]) -> None:
     """
     Inspect the dataframes.
 
@@ -16,16 +14,16 @@ def inspect_data(dataframes: Dict[str, pd.DataFrame]):
     dataframes : Dict[str, pd.DataFrame]
         Dictionary of method names to dataframes.
     """
-    print("=" * 80)
-    print("Inspecting dataframes...")
-    print("=" * 80)
+    print('=' * 80)
+    print('Inspecting dataframes...')
+    print('=' * 80)
     for method, df in dataframes.items():
-        print(f"Method: {method}")
+        print(f'Method: {method}')
         inspect_df(df)
-    print("=" * 80)
+    print('=' * 80)
 
 
-def inspect_df(df: pd.DataFrame):
+def inspect_df(df: pd.DataFrame) -> None:
     """
     Inspect the dataframe.
 
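For context, inspect_data just pretty-prints each method's dataframe (inspect_df calls df.to_markdown(), which needs the tabulate package). A minimal usage sketch with made-up data:

    import pandas as pd
    from jax_hpc_profiler.utils import inspect_data

    # Prints a banner plus one markdown table per method.
    inspect_data({'jaxdecomp': pd.DataFrame({'gpus': [4, 8], 'mean_time': [1.2, 0.9]})})
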
@@ -35,24 +33,24 @@ def inspect_df(df: pd.DataFrame):
         The dataframe to inspect.
     """
     print(df.to_markdown())
-    print("-" * 80)
+    print('-' * 80)
 
 
 params_dict = {
-    "%pn%": "%plot_name%",
-    "%m%": "%method_name%",
-    "%n%": "%node%",
-    "%b%": "%backend%",
-    "%f%": "%function%",
-    "%cn%": "%column_name%",
-    "%pr%": "%precision%",
-    "%p%": "%decomposition%",
-    "%d%": "%data_size%",
-    "%g%": "%nb_gpu%"
+    '%pn%': '%plot_name%',
+    '%m%': '%method_name%',
+    '%n%': '%node%',
+    '%b%': '%backend%',
+    '%f%': '%function%',
+    '%cn%': '%column_name%',
+    '%pr%': '%precision%',
+    '%p%': '%decomposition%',
+    '%d%': '%data_size%',
+    '%g%': '%nb_gpu%',
 }
 
 
-def expand_label(label_template: str, params: dict) -> str:
+def expand_label(label_template: str, params: dict[str, str]) -> str:
     """
     Expand the label template with the provided parameters.
 
@@ -72,14 +70,20 @@ def expand_label(label_template: str, params: dict) -> str:
         label_template = label_template.replace(key, value)
 
     for key, value in params.items():
-        label_template = label_template.replace(f"%{key}%", str(value))
+        label_template = label_template.replace(f'%{key}%', str(value))
     return label_template
 
 
-def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
-                             pdims_strategy: List[str],
-                             print_decompositions: bool, x_col: str,
-                             y_col: str, label_template: str):
+def plot_with_pdims_strategy(
+    ax: Axes,
+    df: pd.DataFrame,
+    method: str,
+    pdims_strategy: List[str],
+    print_decompositions: bool,
+    x_col: str,
+    y_col: str,
+    label_template: str,
+) -> Optional[Tuple[List[float], List[float]]]:
     """
     Plot the data based on the pdims strategy.
 
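For context, expand_label substitutes in two passes: short aliases are first mapped through params_dict (e.g. %m% to %method_name%), then the long placeholders are filled from params. A minimal sketch, assuming the first loop (not shown in this hunk) iterates over params_dict:

    from jax_hpc_profiler.utils import expand_label

    # '%m%' -> '%method_name%' -> 'jaxdecomp'; '%g%' -> '%nb_gpu%' -> '8'
    label = expand_label('%m% on %g% GPUs', {'method_name': 'jaxdecomp', 'nb_gpu': 8})
    assert label == 'jaxdecomp on 8 GPUs'
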
@@ -109,12 +113,12 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
         Template for plot labels with placeholders.
     """
     label_params = {
-        "plot_name": y_col,
-        "method_name": method,
-        "backend": df['backend'].values[0],
-        "node": df['nodes'].values[0],
-        "precision": df['precision'].values[0],
-        "function": df['function'].values[0],
+        'plot_name': y_col,
+        'method_name': method,
+        'backend': df['backend'].values[0],
+        'node': df['nodes'].values[0],
+        'precision': df['precision'].values[0],
+        'function': df['function'].values[0],
     }
 
     if 'plot_fastest' in pdims_strategy:
@@ -126,51 +130,50 @@ def plot_with_pdims_strategy(ax: Axes, df: pd.DataFrame, method: str,
             group.sort_values(by=[y_col], inplace=True, ascending=True)
             sorted_dfs.append(group.iloc[0])
         sorted_df = pd.DataFrame(sorted_dfs)
-        label_params.update({
-            "decomposition":
-            f"{group['px'].values[0]}x{group['py'].values[0]}"
-        })
+        label_params.update({'decomposition': f'{group["px"].values[0]}x{group["py"].values[0]}'})
         label = expand_label(label_template, label_params)
-        ax.plot(sorted_df[x_col].values,
-                sorted_df[y_col],
-                marker='o',
-                linestyle='-',
-                label=label)
+        ax.plot(
+            sorted_df[x_col].values,
+            sorted_df[y_col],
+            marker='o',
+            linestyle='-',
+            label=label,
+        )
         # TODO(wassim) : this is not working very well
         if print_decompositions:
-            for j, (px, py) in enumerate(zip(sorted_df['px'],
-                                             sorted_df['py'])):
+            for j, (px, py) in enumerate(zip(sorted_df['px'], sorted_df['py'])):
                 ax.annotate(
-                    f"{px}x{py}",
+                    f'{px}x{py}',
                     (sorted_df[x_col].values[j], sorted_df[y_col].values[j]),
-                    textcoords="offset points",
+                    textcoords='offset points',
                     xytext=(0, 10),
                     ha='center',
-                    color='red' if j == 0 else 'white')
+                    color='red' if j == 0 else 'white',
+                )
         return sorted_df[x_col].values, sorted_df[y_col].values
 
-    elif any(strategy in pdims_strategy
-             for strategy in ['plot_all', 'slab_yz', 'slab_xy', 'pencils']):
+    elif any(
+        strategy in pdims_strategy for strategy in ['plot_all', 'slab_yz', 'slab_xy', 'pencils']
+    ):
         df_decomp = df.groupby(['decomp'])
         x_values = []
         y_values = []
         for _, group in df_decomp:
-            group.drop_duplicates(subset=[x_col, 'decomp'],
-                                  keep='last',
-                                  inplace=True)
+            group.drop_duplicates(subset=[x_col, 'decomp'], keep='last', inplace=True)
             group.sort_values(by=[x_col], inplace=True, ascending=False)
             # filter decomp based on pdims_strategy
-            if 'plot_all' not in pdims_strategy and group['decomp'].values[
-                    0] not in pdims_strategy:
+            if 'plot_all' not in pdims_strategy and group['decomp'].values[0] not in pdims_strategy:
                 continue
 
-            label_params.update({"decomposition": group['decomp'].values[0]})
+            label_params.update({'decomposition': group['decomp'].values[0]})
             label = expand_label(label_template, label_params)
-            ax.plot(group[x_col].values,
-                    group[y_col],
-                    marker='o',
-                    linestyle='-',
-                    label=label)
+            ax.plot(
+                group[x_col].values,
+                group[y_col],
+                marker='o',
+                linestyle='-',
+                label=label,
+            )
             x_values.extend(group[x_col].values)
             y_values.extend(group[y_col].values)
         return x_values, y_values
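
A hypothetical call, to show the shape of the inputs the function expects. The column names follow the CSV schema used below; the values are invented, and this is a sketch rather than a verified example, since parts of the function fall outside this diff:

    import matplotlib.pyplot as plt
    import pandas as pd
    from jax_hpc_profiler.utils import plot_with_pdims_strategy

    # Made-up benchmark rows for one method at a fixed GPU count.
    df = pd.DataFrame({
        'function': ['fft3d'] * 3, 'precision': ['float32'] * 3,
        'backend': ['NCCL'] * 3, 'nodes': [1] * 3,
        'px': [1, 2, 4], 'py': [8, 4, 2],
        'gpus': [8, 8, 8], 'mean_time': [0.9, 1.1, 1.4],
    })
    fig, ax = plt.subplots()
    xy = plot_with_pdims_strategy(ax, df, 'jaxdecomp', ['plot_fastest'],
                                  print_decompositions=False, x_col='gpus',
                                  y_col='mean_time', label_template='%m% (%p%)')
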
@@ -204,39 +207,63 @@ def concatenate_csvs(root_dir: str, output_dir: str):
             if file.endswith('.csv'):
                 csv_file_path = os.path.join(root, file)
                 print(f'Concatenating {csv_file_path}...')
-                df = pd.read_csv(csv_file_path,
-                                 header=None,
-                                 names=[
-                                     "function", "precision", "x", "y",
-                                     "z", "px", "py", "backend", "nodes",
-                                     "jit_time", "min_time", "max_time",
-                                     "mean_time", "std_div", "last_time",
-                                     "generated_code", "argument_size",
-                                     "output_size", "temp_size", "flops"
-                                 ],
-                                 index_col=False)
+                df = pd.read_csv(
+                    csv_file_path,
+                    header=None,
+                    names=[
+                        'function',
+                        'precision',
+                        'x',
+                        'y',
+                        'z',
+                        'px',
+                        'py',
+                        'backend',
+                        'nodes',
+                        'jit_time',
+                        'min_time',
+                        'max_time',
+                        'mean_time',
+                        'std_div',
+                        'last_time',
+                        'generated_code',
+                        'argument_size',
+                        'output_size',
+                        'temp_size',
+                        'flops',
+                    ],
+                    index_col=False,
+                )
                 if file not in combined_dfs:
                     combined_dfs[file] = df
                 else:
-                    combined_dfs[file] = pd.concat(
-                        [combined_dfs[file], df], ignore_index=True)
+                    combined_dfs[file] = pd.concat([combined_dfs[file], df], ignore_index=True)
 
     # Remove duplicates based on specified columns and save
     for file_name, combined_df in combined_dfs.items():
-        combined_df.drop_duplicates(subset=[
-            "function", "precision", "x", "y", "z", "px", "py", "backend",
-            "nodes"
-        ],
-                                    keep='last',
-                                    inplace=True)
+        combined_df.drop_duplicates(
+            subset=[
+                'function',
+                'precision',
+                'x',
+                'y',
+                'z',
+                'px',
+                'py',
+                'backend',
+                'nodes',
+            ],
+            keep='last',
+            inplace=True,
+        )
 
         gpu_output_dir = os.path.join(output_dir, gpu)
         if not os.path.exists(gpu_output_dir):
-            print(f"Creating directory {gpu_output_dir}")
+            print(f'Creating directory {gpu_output_dir}')
             os.makedirs(gpu_output_dir)
 
         output_file = os.path.join(gpu_output_dir, file_name)
-        print(f"Writing file to {output_file}...")
+        print(f'Writing file to {output_file}...')
         combined_df.to_csv(output_file, index=False)
 
 
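Usage is just the two directories: the function walks root_dir, merges same-named CSVs using the header-less 20-column schema above, drops duplicate benchmark rows (keeping the last occurrence), and writes one combined CSV per file into a per-GPU subdirectory of output_dir. A sketch with invented paths:

    from jax_hpc_profiler.utils import concatenate_csvs

    # 'traces/' and 'merged/' are illustrative paths, not defaults.
    concatenate_csvs('traces/', 'merged/')
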
@@ -287,42 +314,59 @@ def clean_up_csv(
         file_name = os.path.splitext(os.path.basename(csv_file))[0]
         ext = os.path.splitext(os.path.basename(csv_file))[1]
         if ext != '.csv':
-            print(f"Ignoring {csv_file} as it is not a CSV file")
+            print(f'Ignoring {csv_file} as it is not a CSV file')
             continue
 
-        df = pd.read_csv(csv_file,
-                         header=None,
-                         skiprows=0,
-                         names=[
-                             "function", "precision", "x", "y", "z", "px",
-                             "py", "backend", "nodes", "jit_time", "min_time",
-                             "max_time", "mean_time", "std_div", "last_time",
-                             "generated_code", "argument_size", "output_size",
-                             "temp_size", "flops"
-                         ],
-                         dtype={
-                             "function": str,
-                             "precision": str,
-                             "x": int,
-                             "y": int,
-                             "z": int,
-                             "px": int,
-                             "py": int,
-                             "backend": str,
-                             "nodes": int,
-                             "jit_time": float,
-                             "min_time": float,
-                             "max_time": float,
-                             "mean_time": float,
-                             "std_div": float,
-                             "last_time": float,
-                             "generated_code": float,
-                             "argument_size": float,
-                             "output_size": float,
-                             "temp_size": float,
-                             "flops": float
-                         },
-                         index_col=False)
+        df = pd.read_csv(
+            csv_file,
+            header=None,
+            skiprows=0,
+            names=[
+                'function',
+                'precision',
+                'x',
+                'y',
+                'z',
+                'px',
+                'py',
+                'backend',
+                'nodes',
+                'jit_time',
+                'min_time',
+                'max_time',
+                'mean_time',
+                'std_div',
+                'last_time',
+                'generated_code',
+                'argument_size',
+                'output_size',
+                'temp_size',
+                'flops',
+            ],
+            dtype={
+                'function': str,
+                'precision': str,
+                'x': int,
+                'y': int,
+                'z': int,
+                'px': int,
+                'py': int,
+                'backend': str,
+                'nodes': int,
+                'jit_time': float,
+                'min_time': float,
+                'max_time': float,
+                'mean_time': float,
+                'std_div': float,
+                'last_time': float,
+                'generated_code': float,
+                'argument_size': float,
+                'output_size': float,
+                'temp_size': float,
+                'flops': float,
+            },
+            index_col=False,
+        )
 
         # Filter precisions
         if precisions:
@@ -360,18 +404,32 @@ def clean_up_csv(
         df['output_size'] = df['output_size'] / factor
         df['temp_size'] = df['temp_size'] / factor
         # in case of the same test is run multiple times, keep the last one
-        df = df.drop_duplicates(subset=[
-            "function", "precision", "x", "y", "z", "px", "py", "backend",
-            "nodes"
-        ],
-                                keep='last')
+        df = df.drop_duplicates(
+            subset=[
+                'function',
+                'precision',
+                'x',
+                'y',
+                'z',
+                'px',
+                'py',
+                'backend',
+                'nodes',
+            ],
+            keep='last',
+        )
 
         df['gpus'] = df['px'] * df['py']
 
         if gpus:
             df = df[df['gpus'].isin(gpus)]
 
-        if 'plot_all' in pdims_strategy or 'slab_yz' in pdims_strategy or 'slab_xy' in pdims_strategy or 'pencils' in pdims_strategy:
+        if (
+            'plot_all' in pdims_strategy
+            or 'slab_yz' in pdims_strategy
+            or 'slab_xy' in pdims_strategy
+            or 'pencils' in pdims_strategy
+        ):
 
             def get_decomp_from_px_py(row):
                 if row['px'] > 1 and row['py'] == 1:
@@ -383,7 +441,7 @@ def clean_up_csv(
 
             df['decomp'] = df.apply(get_decomp_from_px_py, axis=1)
             df.drop(columns=['px', 'py'], inplace=True)
-            if not 'plot_all' in pdims_strategy:
+            if 'plot_all' not in pdims_strategy:
                 df = df[df['decomp'].isin(pdims_strategy)]
 
         # check available gpus in dataset
@@ -395,17 +453,18 @@ def clean_up_csv(
         else:
             dataframes[file_name] = pd.concat([dataframes[file_name], df])
 
-    print(f"requested GPUS: {gpus} available GPUS: {available_gpu_counts}")
-    print(
-        f"requested data sizes: {data_sizes} available data sizes: {available_data_sizes}"
-    )
+    print(f'requested GPUS: {gpus} available GPUS: {available_gpu_counts}')
+    print(f'requested data sizes: {data_sizes} available data sizes: {available_data_sizes}')
 
-    available_gpu_counts = (available_gpu_counts if gpus is None else [
-        gpu for gpu in gpus if gpu in available_gpu_counts
-    ])
-    available_data_sizes = (available_data_sizes if data_sizes is None else [
-        data_size for data_size in data_sizes
-        if data_size in available_data_sizes
-    ])
+    available_gpu_counts = (
+        available_gpu_counts
+        if gpus is None
+        else [gpu for gpu in gpus if gpu in available_gpu_counts]
+    )
+    available_data_sizes = (
+        available_data_sizes
+        if data_sizes is None
+        else [data_size for data_size in data_sizes if data_size in available_data_sizes]
+    )
 
     return dataframes, available_gpu_counts, available_data_sizes
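
An illustrative call: only part of clean_up_csv's signature appears in this diff, so the parameter names below simply mirror identifiers visible in the body (csv_files, precisions, gpus, data_sizes, pdims_strategy) and are not a confirmed API:

    from jax_hpc_profiler.utils import clean_up_csv

    # Hypothetical invocation; returns per-file dataframes plus the GPU counts
    # and data sizes actually present after filtering.
    dataframes, gpu_counts, data_sizes = clean_up_csv(
        ['merged/A100/fft3d.csv'],
        precisions=['float32'],
        gpus=[4, 8],
        data_sizes=None,
        pdims_strategy=['plot_all'],
    )
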
jax_hpc_profiler-0.2.13.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: jax_hpc_profiler
-Version: 0.2.11
+Version: 0.2.13
 Summary: HPC Plotter and profiler for benchmarking data made for JAX
 Author: Wassim Kabalan
 License: GNU GENERAL PUBLIC LICENSE
jax_hpc_profiler-0.2.13.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+jax_hpc_profiler/__init__.py,sha256=c2n33ZXVgUS8vo5xAEW-TcSi_UzJp616KdGEb3iO6p4,388
+jax_hpc_profiler/create_argparse.py,sha256=J1RF4n2e85QReoI_fqXxK5BMAUgzueHmObKOh4YHopE,5821
+jax_hpc_profiler/main.py,sha256=YPLkZCmtjzNoDrzTA4CWL8y39Spz3qbCS91eP2pqP5Y,2224
+jax_hpc_profiler/plotting.py,sha256=Lg157H3mrF3zHc4BIplddKu9f0viQkaQhtCCAQBxinE,9167
+jax_hpc_profiler/timer.py,sha256=0lbJgNh3GT1dFOpNOA4Fwvsm9JNp-J1xDdLFaaQ6jaY,10237
+jax_hpc_profiler/utils.py,sha256=IfGDbKldJXiDhxb02IxmQV51SFIBYLDUL7Se_OtEOkc,14963
+jax_hpc_profiler-0.2.13.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+jax_hpc_profiler-0.2.13.dist-info/METADATA,sha256=YyHfP98Vz8ya23YsRPV2rehbRoFsO3pziOgnoX5DitE,49186
+jax_hpc_profiler-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+jax_hpc_profiler-0.2.13.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
+jax_hpc_profiler-0.2.13.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
+jax_hpc_profiler-0.2.13.dist-info/RECORD,,
jax_hpc_profiler-0.2.13.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (77.0.1)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
@@ -1,12 +0,0 @@
1
- jax_hpc_profiler/__init__.py,sha256=yDWt2S0xJZeS6YLBFvyPj1p5txCgFx2kCxJzVZZcdtI,367
2
- jax_hpc_profiler/create_argparse.py,sha256=6DVpYuj908L01vk09-l7BLU6dW4OGgTepFR123fsawM,6314
3
- jax_hpc_profiler/main.py,sha256=dwOik2rJw5YV6ocQ-EE32iFOPlq2_3CHHAAuJJFt65Q,2286
4
- jax_hpc_profiler/plotting.py,sha256=R0mjUhV_Q-qi02mlxWiR241sxr58USBSykwFdjBa-oM,9484
5
- jax_hpc_profiler/timer.py,sha256=4zc5HlJwepMK633BDz0iLTLWcLsvPdd6M1SL0-qs4js,10554
6
- jax_hpc_profiler/utils.py,sha256=7i8qPfKogp8nGaGdyJ2-fbQomhIZqn73PQ14qldpFTc,14657
7
- jax_hpc_profiler-0.2.11.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
8
- jax_hpc_profiler-0.2.11.dist-info/METADATA,sha256=mTTBpWxSRjhWAI9jZ-TdvwxFlpm9j2BhCe77TfbITTM,49186
9
- jax_hpc_profiler-0.2.11.dist-info/WHEEL,sha256=tTnHoFhvKQHCh4jz3yCn0WPTYIy7wXx3CJtJ7SJGV7c,91
10
- jax_hpc_profiler-0.2.11.dist-info/entry_points.txt,sha256=_cFlxSINscX3ZyNiklfjyOOO7vNkddhoYy_v1JQHSO4,51
11
- jax_hpc_profiler-0.2.11.dist-info/top_level.txt,sha256=DKAhVKDwkerhth-xo7oKFSnnKE0Xm46m94b06vZksA4,17
12
- jax_hpc_profiler-0.2.11.dist-info/RECORD,,