jax-hpc-profiler 0.2.12__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,23 +4,23 @@ from typing import Dict, List, Optional
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
6
  import pandas as pd
7
- import seaborn as sns
7
+ from adjustText import adjust_text
8
8
  from matplotlib.axes import Axes
9
9
  from matplotlib.patches import FancyBboxPatch
10
10
 
11
- from .utils import clean_up_csv, inspect_df, plot_with_pdims_strategy
11
+ from .utils import clean_up_csv, plot_with_pdims_strategy
12
12
 
13
- np.seterr(divide="ignore")
13
+ np.seterr(divide='ignore')
14
14
 
15
15
 
16
16
  def configure_axes(
17
17
  ax: Axes,
18
18
  x_values: List[int],
19
19
  y_values: List[float],
20
- title: str,
20
+ title: Optional[str],
21
21
  xlabel: str,
22
22
  plotting_memory: bool = False,
23
- memory_units: str = "bytes",
23
+ memory_units: str = 'bytes',
24
24
  ):
25
25
  """
26
26
  Configure the axes for the plot.
@@ -36,33 +36,32 @@ def configure_axes(
36
36
  xlabel : str
37
37
  The label for the x-axis.
38
38
  """
39
- ylabel = ("Time (milliseconds)"
40
- if not plotting_memory else f"Memory ({memory_units})")
41
- f2 = lambda x: np.log2(x)
42
- g2 = lambda x: 2**x
39
+ ylabel = 'Time (milliseconds)' if not plotting_memory else f'Memory ({memory_units})'
40
+
41
+ def f2(x):
42
+ return np.log2(x)
43
+
44
+ def g2(x):
45
+ return 2**x
46
+
43
47
  ax.set_xlim([min(x_values), max(x_values)])
44
48
  y_min, y_max = min(y_values) * 0.6, max(y_values) * 1.1
45
49
  ax.set_title(title)
46
50
  ax.set_ylim([y_min, y_max])
47
- ax.set_xscale("function", functions=(f2, g2))
51
+ ax.set_xscale('function', functions=(f2, g2))
48
52
  if not plotting_memory:
49
- ax.set_yscale("symlog")
53
+ ax.set_yscale('symlog')
50
54
  time_ticks = [
51
- 10**t for t in range(int(np.floor(np.log10(y_min))), 1 +
52
- int(np.ceil(np.log10(y_max))))
55
+ 10**t for t in range(int(np.floor(np.log10(y_min))), 1 + int(np.ceil(np.log10(y_max))))
53
56
  ]
54
57
  ax.set_yticks(time_ticks)
55
58
  ax.set_xticks(x_values)
56
59
  ax.set_xlabel(xlabel)
57
60
  ax.set_ylabel(ylabel)
58
61
  for x_value in x_values:
59
- ax.axvline(x=x_value, color="gray", linestyle="--", alpha=0.5)
62
+ ax.axvline(x=x_value, color='gray', linestyle='--', alpha=0.5)
60
63
  ax.legend(
61
- loc="lower center",
62
- bbox_to_anchor=(0.5, 0.05),
63
- ncol=4,
64
- fontsize="x-large",
65
- prop={"size": 14},
64
+ loc='best',
66
65
  )
67
66
 
68
67
 
@@ -80,10 +79,10 @@ def plot_scaling(
80
79
  backends: Optional[List[str]] = None,
81
80
  precisions: Optional[List[str]] = None,
82
81
  functions: Optional[List[str]] = None,
83
- plot_columns: List[str] = ["mean_time"],
84
- memory_units: str = "bytes",
85
- label_text: str = "plot",
86
- pdims_strategy: List[str] = ["plot_fastest"],
82
+ plot_columns: List[str] = ['mean_time'],
83
+ memory_units: str = 'bytes',
84
+ label_text: str = 'plot',
85
+ pdims_strategy: List[str] = ['plot_fastest'],
87
86
  ):
88
87
  """
89
88
  General scaling plot function based on the number of GPUs or data size.
@@ -115,7 +114,7 @@ def plot_scaling(
115
114
  """
116
115
 
117
116
  if dark_bg:
118
- plt.style.use("dark_background")
117
+ plt.style.use('dark_background')
119
118
 
120
119
  num_subplots = len(fixed_sizes)
121
120
  num_rows = int(np.ceil(np.sqrt(num_subplots)))
@@ -133,28 +132,26 @@ def plot_scaling(
133
132
  x_values = []
134
133
  y_values = []
135
134
  for method, df in dataframes.items():
136
-
137
135
  filtered_method_df = df[df[fixed_column] == int(fixed_size)]
138
136
  if filtered_method_df.empty:
139
137
  continue
140
- filtered_method_df = filtered_method_df.sort_values(
141
- by=[size_column])
142
- functions = (pd.unique(filtered_method_df["function"])
143
- if functions is None else functions)
144
- precisions = (pd.unique(filtered_method_df["precision"])
145
- if precisions is None else precisions)
146
- backends = (pd.unique(filtered_method_df["backend"])
147
- if backends is None else backends)
148
-
149
- combinations = product(backends, precisions, functions,
150
- plot_columns)
138
+ filtered_method_df = filtered_method_df.sort_values(by=[size_column])
139
+ functions = (
140
+ pd.unique(filtered_method_df['function']) if functions is None else functions
141
+ )
142
+ precisions = (
143
+ pd.unique(filtered_method_df['precision']) if precisions is None else precisions
144
+ )
145
+ backends = pd.unique(filtered_method_df['backend']) if backends is None else backends
151
146
 
152
- for backend, precision, function, plot_column in combinations:
147
+ combinations = product(backends, precisions, functions, plot_columns)
153
148
 
149
+ for backend, precision, function, plot_column in combinations:
154
150
  filtered_params_df = filtered_method_df[
155
- (filtered_method_df["backend"] == backend)
156
- & (filtered_method_df["precision"] == precision)
157
- & (filtered_method_df["function"] == function)]
151
+ (filtered_method_df['backend'] == backend)
152
+ & (filtered_method_df['precision'] == precision)
153
+ & (filtered_method_df['function'] == function)
154
+ ]
158
155
  if filtered_params_df.empty:
159
156
  continue
160
157
  x_vals, y_vals = plot_with_pdims_strategy(
@@ -172,12 +169,13 @@ def plot_scaling(
172
169
  y_values.extend(y_vals)
173
170
 
174
171
  if len(x_values) != 0:
175
- plotting_memory = "time" not in plot_columns[0].lower()
172
+ plotting_memory = 'time' not in plot_columns[0].lower()
173
+ figure_title = f'{title} {fixed_size}' if title is not None else None
176
174
  configure_axes(
177
175
  ax,
178
176
  x_values,
179
177
  y_values,
180
- f"{title} {fixed_size}",
178
+ figure_title,
181
179
  xlabel,
182
180
  plotting_memory,
183
181
  memory_units,
@@ -187,17 +185,12 @@ def plot_scaling(
187
185
  fig.delaxes(axs[i])
188
186
 
189
187
  fig.tight_layout()
190
- rect = FancyBboxPatch((0.1, 0.1),
191
- 0.8,
192
- 0.8,
193
- boxstyle="round,pad=0.02",
194
- ec="black",
195
- fc="none")
188
+ rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
196
189
  fig.patches.append(rect)
197
190
  if output is None:
198
191
  plt.show()
199
192
  else:
200
- plt.savefig(output, bbox_inches="tight", transparent=True)
193
+ plt.savefig(output)
201
194
 
202
195
 
203
196
  def plot_strong_scaling(
@@ -207,14 +200,14 @@ def plot_strong_scaling(
207
200
  functions: Optional[List[str]] = None,
208
201
  precisions: Optional[List[str]] = None,
209
202
  pdims: Optional[List[str]] = None,
210
- pdims_strategy: List[str] = ["plot_fastest"],
203
+ pdims_strategy: List[str] = ['plot_fastest'],
211
204
  print_decompositions: bool = False,
212
205
  backends: Optional[List[str]] = None,
213
- plot_columns: List[str] = ["mean_time"],
214
- memory_units: str = "bytes",
215
- label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
216
- xlabel: str = "Number of GPUs",
217
- title: str = "Data sizes",
206
+ plot_columns: List[str] = ['mean_time'],
207
+ memory_units: str = 'bytes',
208
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
209
+ xlabel: str = 'Number of GPUs',
210
+ title: str = 'Data sizes',
218
211
  figure_size: tuple = (6, 4),
219
212
  dark_bg: bool = False,
220
213
  output: Optional[str] = None,
@@ -235,14 +228,14 @@ def plot_strong_scaling(
235
228
  memory_units,
236
229
  )
237
230
  if len(dataframes) == 0:
238
- print(f"No dataframes found for the given arguments. Exiting...")
231
+ print('No dataframes found for the given arguments. Exiting...')
239
232
  return
240
233
 
241
234
  plot_scaling(
242
235
  dataframes,
243
236
  available_data_sizes,
244
- "gpus",
245
- "x",
237
+ 'gpus',
238
+ 'x',
246
239
  xlabel,
247
240
  title,
248
241
  figure_size,
@@ -266,20 +259,209 @@ def plot_weak_scaling(
266
259
  functions: Optional[List[str]] = None,
267
260
  precisions: Optional[List[str]] = None,
268
261
  pdims: Optional[List[str]] = None,
269
- pdims_strategy: List[str] = ["plot_fastest"],
262
+ pdims_strategy: List[str] = ['plot_fastest'],
263
+ print_decompositions: bool = False,
264
+ backends: Optional[List[str]] = None,
265
+ plot_columns: List[str] = ['mean_time'],
266
+ memory_units: str = 'bytes',
267
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
268
+ xlabel: str = 'Number of GPUs',
269
+ title: Optional[str] = None,
270
+ figure_size: tuple = (6, 4),
271
+ dark_bg: bool = False,
272
+ output: Optional[str] = None,
273
+ ideal_line: bool = False,
274
+ reverse_axes: bool = False,
275
+ ):
276
+ """
277
+ Plot true weak scaling: runtime vs GPUs for explicit (gpus, data size) sequences.
278
+
279
+ Both ``fixed_gpu_size`` and ``fixed_data_size`` must be provided and have the same length,
280
+ representing explicit weak-scaling pairs (gpus[i], data_size[i]).
281
+
282
+ reverse_axes:
283
+ - False (default): x-axis is GPUs, y-axis is time; points are annotated with
284
+ ``N=<data_size>``.
285
+ - True: x-axis is data size, y-axis is time; points are annotated with ``GPUs=<gpu_count>``.
286
+ """
287
+ if fixed_gpu_size is None or fixed_data_size is None:
288
+ raise ValueError(
289
+ 'Weak scaling requires both fixed_gpu_size (gpus) and fixed_data_size (problem sizes).'
290
+ )
291
+ if len(fixed_gpu_size) != len(fixed_data_size):
292
+ raise ValueError(
293
+ 'Weak scaling requires fixed_gpu_size and fixed_data_size lists of equal length.'
294
+ )
295
+
296
+ gpu_to_data = {int(g): int(d) for g, d in zip(fixed_gpu_size, fixed_data_size)}
297
+ data_to_gpu = {int(d): int(g) for g, d in zip(fixed_gpu_size, fixed_data_size)}
298
+ x_col = 'x' if reverse_axes else 'gpus'
299
+
300
+ dataframes, _, _ = clean_up_csv(
301
+ csv_files,
302
+ precisions,
303
+ functions,
304
+ fixed_gpu_size,
305
+ fixed_data_size,
306
+ pdims,
307
+ pdims_strategy,
308
+ backends,
309
+ memory_units,
310
+ )
311
+ if len(dataframes) == 0:
312
+ print('No dataframes found for the given arguments. Exiting...')
313
+ return
314
+
315
+ if dark_bg:
316
+ plt.style.use('dark_background')
317
+
318
+ fig, ax = plt.subplots(figsize=figure_size)
319
+
320
+ x_values: List[float] = []
321
+ y_values: List[float] = []
322
+ annotations: List = []
323
+ ideal_line_plotted = False
324
+
325
+ for method, df in dataframes.items():
326
+ # Determine parameter sets from the filtered dataframe if not provided
327
+ local_functions = pd.unique(df['function']) if functions is None else functions
328
+ local_precisions = pd.unique(df['precision']) if precisions is None else precisions
329
+ local_backends = pd.unique(df['backend']) if backends is None else backends
330
+
331
+ combinations = product(local_backends, local_precisions, local_functions, plot_columns)
332
+
333
+ for backend, precision, function, plot_column in combinations:
334
+ base_df = df[
335
+ (df['backend'] == backend)
336
+ & (df['precision'] == precision)
337
+ & (df['function'] == function)
338
+ ]
339
+ if base_df.empty:
340
+ continue
341
+
342
+ # Keep only rows matching any of the (gpus, x) pairs
343
+ mask = pd.Series(False, index=base_df.index)
344
+ for g, d in zip(fixed_gpu_size, fixed_data_size):
345
+ mask |= (base_df['gpus'] == int(g)) & (base_df['x'] == int(d))
346
+
347
+ filtered_params_df = base_df[mask]
348
+ if filtered_params_df.empty:
349
+ continue
350
+
351
+ x_vals, y_vals = plot_with_pdims_strategy(
352
+ ax,
353
+ filtered_params_df,
354
+ method,
355
+ pdims_strategy,
356
+ print_decompositions,
357
+ x_col,
358
+ plot_column,
359
+ label_text,
360
+ )
361
+ if x_vals is None or len(x_vals) == 0:
362
+ continue
363
+
364
+ x_arr = np.asarray(x_vals).reshape(-1)
365
+ y_arr = np.asarray(y_vals).reshape(-1)
366
+
367
+ # Annotate every point with data size or GPU count depending on axis choice.
368
+ # Use plain data coordinates for the text; adjust_text will then only move
369
+ # the labels slightly (mostly vertically) to avoid overlap.
370
+ for xv, yv in zip(x_arr, y_arr):
371
+ if reverse_axes:
372
+ gpu = data_to_gpu.get(int(xv))
373
+ if gpu is None:
374
+ continue
375
+ label = f'GPUs={gpu}'
376
+ else:
377
+ data_size = gpu_to_data.get(int(xv))
378
+ if data_size is None:
379
+ continue
380
+ label = f'N={data_size}'
381
+
382
+ text_obj = ax.text(
383
+ float(xv),
384
+ float(yv),
385
+ label,
386
+ ha='center',
387
+ va='bottom',
388
+ fontsize='small',
389
+ clip_on=True,
390
+ )
391
+ annotations.append(text_obj)
392
+
393
+ x_values.extend(x_arr.tolist())
394
+ y_values.extend(y_arr.tolist())
395
+
396
+ if ideal_line and not ideal_line_plotted:
397
+ # Use the smallest x value in this curve as baseline
398
+ baseline_index = np.argmin(x_arr)
399
+ baseline_y = y_arr[baseline_index]
400
+ ax.hlines(
401
+ baseline_y,
402
+ xmin=float(np.min(x_arr)),
403
+ xmax=float(np.max(x_arr)),
404
+ colors='gray',
405
+ linestyles='dashed',
406
+ label='Ideal weak scaling',
407
+ )
408
+ ideal_line_plotted = True
409
+ y_values.append(float(baseline_y))
410
+
411
+ if x_values:
412
+ plotting_memory = 'time' not in plot_columns[0].lower()
413
+ figure_title = title if title is not None else 'Weak scaling'
414
+ configure_axes(
415
+ ax,
416
+ x_values,
417
+ y_values,
418
+ figure_title,
419
+ xlabel,
420
+ plotting_memory,
421
+ memory_units,
422
+ )
423
+ if annotations:
424
+ ax.figure.canvas.draw()
425
+ adjust_text(
426
+ annotations,
427
+ ax=ax,
428
+ # keep points aligned in x, only allow vertical motion
429
+ only_move={'text': 'y', 'static': 'y'},
430
+ expand=(1.02, 1.05),
431
+ force_text=(0.08, 0.2),
432
+ max_move=(0, 30),
433
+ )
434
+
435
+ fig.tight_layout()
436
+ rect = FancyBboxPatch((0.1, 0.1), 0.8, 0.8, boxstyle='round,pad=0.02', ec='black', fc='none')
437
+ fig.patches.append(rect)
438
+ if output is None:
439
+ plt.show()
440
+ else:
441
+ plt.savefig(output)
442
+
443
+
444
+ def plot_weak_fixed_scaling(
445
+ csv_files: List[str],
446
+ fixed_gpu_size: Optional[List[int]] = None,
447
+ fixed_data_size: Optional[List[int]] = None,
448
+ functions: Optional[List[str]] = None,
449
+ precisions: Optional[List[str]] = None,
450
+ pdims: Optional[List[str]] = None,
451
+ pdims_strategy: List[str] = ['plot_fastest'],
270
452
  print_decompositions: bool = False,
271
453
  backends: Optional[List[str]] = None,
272
- plot_columns: List[str] = ["mean_time"],
273
- memory_units: str = "bytes",
274
- label_text: str = "%m%-%f%-%pn%-%pr%-%b%-%p%-%n%",
275
- xlabel: str = "Data sizes",
276
- title: str = "Number of GPUs",
454
+ plot_columns: List[str] = ['mean_time'],
455
+ memory_units: str = 'bytes',
456
+ label_text: str = '%m%-%f%-%pn%-%pr%-%b%-%p%-%n%',
457
+ xlabel: str = 'Data sizes',
458
+ title: str = 'Number of GPUs',
277
459
  figure_size: tuple = (6, 4),
278
460
  dark_bg: bool = False,
279
461
  output: Optional[str] = None,
280
462
  ):
281
463
  """
282
- Plot weak scaling based on the data size.
464
+ Plot size scaling at fixed GPU count (previous weak-scaling behavior).
283
465
  """
284
466
  dataframes, available_gpu_counts, _ = clean_up_csv(
285
467
  csv_files,
@@ -293,14 +475,14 @@ def plot_weak_scaling(
293
475
  memory_units,
294
476
  )
295
477
  if len(dataframes) == 0:
296
- print(f"No dataframes found for the given arguments. Exiting...")
478
+ print('No dataframes found for the given arguments. Exiting...')
297
479
  return
298
480
 
299
481
  plot_scaling(
300
482
  dataframes,
301
483
  available_gpu_counts,
302
- "x",
303
- "gpus",
484
+ 'x',
485
+ 'gpus',
304
486
  xlabel,
305
487
  title,
306
488
  figure_size,