onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +22 -5
  3. onnx_diagnostic/ext_test_case.py +31 -0
  4. onnx_diagnostic/helpers/cache_helper.py +23 -12
  5. onnx_diagnostic/helpers/config_helper.py +16 -1
  6. onnx_diagnostic/helpers/log_helper.py +308 -83
  7. onnx_diagnostic/helpers/rt_helper.py +11 -1
  8. onnx_diagnostic/helpers/torch_helper.py +7 -3
  9. onnx_diagnostic/tasks/__init__.py +2 -0
  10. onnx_diagnostic/tasks/text_generation.py +17 -8
  11. onnx_diagnostic/tasks/text_to_image.py +91 -0
  12. onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
  15. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
  16. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  17. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  18. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  19. onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
  20. onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
  23. onnx_diagnostic/torch_models/validate.py +36 -12
  24. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
  25. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
  26. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/log_helper.py

@@ -21,6 +21,63 @@ BUCKET_SCALES_VALUES = np.array(
 BUCKET_SCALES = BUCKET_SCALES_VALUES / 100 + 1
 
 
+def filter_data(
+    df: pandas.DataFrame,
+    filter_in: Optional[str] = None,
+    filter_out: Optional[str] = None,
+    verbose: int = 0,
+) -> pandas.DataFrame:
+    """
+    Argument `filter` follows the syntax
+    ``<column1>:<fmt1>//<column2>:<fmt2>``.
+
+    The format is the following:
+
+    * a value or a set of values separated by ``;``
+    """
+    if not filter_in and not filter_out:
+        return df
+
+    def _f(fmt):
+        cond = {}
+        if isinstance(fmt, str):
+            cols = fmt.split("//")
+            for c in cols:
+                assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}"
+                spl = c.split(":")
+                assert len(spl) == 2, f"Unexpected value {c!r} in fmt={fmt!r}"
+                name, fil = spl
+                cond[name] = set(fil.split(";"))
+        return cond
+
+    if filter_in:
+        cond = _f(filter_in)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_in!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter in column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[df[k].astype(str).isin(v)]
+
+    if filter_out:
+        cond = _f(filter_out)
+        assert isinstance(cond, dict), f"Unexpected type {type(cond)} for fmt={filter_out!r}"
+        for k, v in cond.items():
+            if k not in df.columns:
+                continue
+            if verbose:
+                print(
+                    f"[_filter_data] filter out column {k!r}, "
+                    f"values {v!r} among {set(df[k].astype(str))}"
+                )
+            df = df[~df[k].astype(str).isin(v)]
+    return df
+
+
 def enumerate_csv_files(
     data: Union[
         pandas.DataFrame, List[Union[str, Tuple[str, str]]], str, Tuple[str, str, str, str]
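To make the new filter syntax concrete, here is a small usage sketch. It assumes `filter_data` is importable from `onnx_diagnostic.helpers.log_helper`, where the hunk above adds it; the column and value names are invented for illustration.

```python
import pandas

from onnx_diagnostic.helpers.log_helper import filter_data

# hypothetical benchmark table, column names are illustrative only
df = pandas.DataFrame(
    {
        "model": ["phi3", "llama", "phi3", "gemma"],
        "exporter": ["dynamo", "custom", "custom", "dynamo"],
        "time": [1.2, 3.4, 2.1, 0.9],
    }
)

# keep rows whose model is phi3 or llama (values separated by ";"),
# then drop the rows produced by the "custom" exporter
kept = filter_data(df, filter_in="model:phi3;llama", filter_out="exporter:custom")
# only the (phi3, dynamo) row survives
```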
@@ -118,7 +175,8 @@ def open_dataframe(
     data: Union[str, Tuple[str, str, str, str], pandas.DataFrame],
 ) -> pandas.DataFrame:
     """
-    Opens a filename.
+    Opens a filename defined by function
+    :func:`onnx_diagnostic.helpers.log_helper.enumerate_csv_files`.
 
     :param data: a dataframe, a filename, a tuple indicating the file is coming
         from a zip file
@@ -259,7 +317,7 @@ def apply_excel_style(
     co: Dict[int, int] = {}
     sizes: Dict[int, int] = {}
     cols = set()
-    for i in range(1, n_rows):
+    for i in range(1, n_rows + 1):
         for j, cell in enumerate(sheet[i]):
             if j > n_cols:
                 break
@@ -277,7 +335,7 @@ def apply_excel_style(
         c = get_column_letter(k)
         sheet.column_dimensions[c].width = 15
 
-    for i in range(1, n_rows):
+    for i in range(1, n_rows + 1):
         for j, cell in enumerate(sheet[i]):
             if j > n_cols:
                 break
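Both `range` fixes above correct the same off-by-one error: openpyxl rows are 1-based and `sheet[i]` is valid for `i` from 1 to `n_rows` inclusive, so the old loops never styled the last row.

```python
# openpyxl indexes rows starting at 1, so the upper bound must be n_rows + 1
n_rows = 3
list(range(1, n_rows))      # [1, 2]    -> last row skipped (0.7.1 behavior)
list(range(1, n_rows + 1))  # [1, 2, 3] -> every row visited (0.7.3 fix)
```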
@@ -333,18 +391,85 @@ def apply_excel_style(
 class CubePlot:
     """
     Creates a plot.
+
+    :param df: dataframe
+    :param kind: kind of graph to plot: bar, barh or line
+    :param split: draw a graph per line in the dataframe
+    :param timeseries: this assumes the time is one level of the columns,
+        this argument indicates the level name
     """
 
+    KINDS = {"bar", "barh", "line"}
+
+    @classmethod
+    def group_columns(
+        cls, columns: List[str], sep: str = "/", depth: int = 2
+    ) -> List[List[str]]:
+        """Groups columns to get a nice display."""
+        res: Dict[str, List[str]] = {}
+        for c in columns:
+            p = c.split("/")
+            k = "/".join(p[:depth])
+            if k not in res:
+                res[k] = []
+            res[k].append(c)
+        new_res: Dict[str, List[str]] = {}
+        for k, v in res.items():
+            if len(v) >= 3:
+                new_res[k] = v
+            else:
+                if "0" not in new_res:
+                    new_res["0"] = []
+                new_res["0"].extend(v)
+        groups: List[List[str]] = [sorted(v) for k, v in sorted(new_res.items())]
+        if depth <= 1:
+            return groups
+        new_groups: List[List[str]] = []
+        for v in groups:
+            if len(v) >= 6:
+                new_groups.extend(cls.group_columns(v, depth=1, sep=sep))
+            else:
+                new_groups.append(v)
+        return new_groups
+
     def __init__(
-        self, df: pandas.DataFrame, kind: str = "bar", orientation="col", split: bool = True
+        self,
+        df: pandas.DataFrame,
+        kind: str = "bar",
+        orientation="col",
+        split: bool = True,
+        timeseries: Optional[str] = None,
     ):
+        assert (
+            not timeseries or timeseries in df.columns.names
+        ), f"Level {timeseries!r} is not part of the columns levels {df.columns.names}"
+        assert (
+            kind in self.__class__.KINDS
+        ), f"Unexpected kind={kind!r} not in {self.__class__.KINDS}"
+        assert split, f"split={split} not implemented"
+        assert (
+            not timeseries or orientation == "row"
+        ), f"orientation={orientation!r} must be 'row' for timeseries"
         self.df = df.copy()
         self.kind = kind
         self.orientation = orientation
         self.split = split
+        self.timeseries = timeseries
+
+        if timeseries:
+            if isinstance(self.df.columns, pandas.MultiIndex):
+                index_time = list(self.df.columns.names).index(self.timeseries)
 
-        if isinstance(self.df.columns, pandas.MultiIndex):
-            self.df.columns = ["/".join(map(str, i)) for i in self.df.columns]
+                def _drop(t, i=index_time):
+                    return (*t[:i], *t[i + 1 :])
+
+                self.df.columns = pandas.MultiIndex.from_tuples(
+                    [("/".join(map(str, _drop(i))), i[index_time]) for i in self.df.columns],
+                    names=["metric", timeseries],
+                )
+        else:
+            if isinstance(self.df.columns, pandas.MultiIndex):
+                self.df.columns = ["/".join(map(str, i)) for i in self.df.columns]
         if isinstance(self.df.index, pandas.MultiIndex):
             self.df.index = ["/".join(map(str, i)) for i in self.df.index]
 
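Based on the source above, `group_columns` buckets column names by their first `depth` path components, keeps buckets with at least three members, and merges the rest into a catch-all group. A sketch of the expected behavior (the metric names are invented):

```python
from onnx_diagnostic.helpers.log_helper import CubePlot

cols = [
    "time/export/phi3",
    "time/export/llama",
    "time/export/gemma",
    "disc/abs",
    "speedup/increase",
]
groups = CubePlot.group_columns(cols)
# expected: the three "time/export/*" columns form one group and the two
# leftovers land in the catch-all bucket, which sorts first:
# [["disc/abs", "speedup/increase"],
#  ["time/export/gemma", "time/export/llama", "time/export/phi3"]]
```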
@@ -354,85 +479,129 @@ class CubePlot:
 
     def to_images(
         self, verbose: int = 0, merge: bool = True, title_suffix: Optional[str] = None
-    ):
+    ) -> List[bytes]:
         """
         Converts data into plots and images.
+
+        :param verbose: verbosity
+        :param merge: returns all graphs in a single image (True)
+            or an image for every graph (False)
+        :param title_suffix: suffix appended to the title of every graph
+        :return: list of binary images (PNG format)
         """
-        import matplotlib.pyplot as plt
+        if self.kind in ("barh", "bar"):
+            return self._to_images_bar(verbose=verbose, merge=merge, title_suffix=title_suffix)
+        if self.kind == "line":
+            return self._to_images_line(
+                verbose=verbose, merge=merge, title_suffix=title_suffix
+            )
+        raise AssertionError(f"self.kind={self.kind!r} not implemented")
 
-        df = self.df.T if self.orientation == "row" else self.df
-        imgs = []
+    @classmethod
+    def _make_loop(cls, ensemble, verbose):
         if verbose:
             from tqdm import tqdm
 
-            loop = tqdm(df.columns)
+            loop = tqdm(ensemble)
         else:
-            loop = df.columns
+            loop = ensemble
+        return loop
+
+    def _to_images_bar(
+        self, verbose: int = 0, merge: bool = True, title_suffix: Optional[str] = None
+    ) -> List[bytes]:
+        assert merge, f"merge={merge} not implemented yet"
+        import matplotlib.pyplot as plt
+
+        df = self.df.T if self.orientation == "row" else self.df
         title_suffix = f"\n{title_suffix}" if title_suffix else ""
-        if merge:
-            nn = len(df.columns) // 2
-            nn += nn % 2
-            fig, axs = plt.subplots(nn, 2, figsize=(12, 3 * nn * df.shape[0] / 12))
-            pos = 0
-            for c in loop:
-                ax = axs[pos // 2, pos % 2]
+
+        n_cols = 3
+        nn = df.shape[1] // n_cols
+        nn += int(df.shape[1] % n_cols != 0)
+        fig, axs = plt.subplots(nn, n_cols, figsize=(6 * n_cols, nn * df.shape[0] / 5))
+        pos = 0
+        imgs = []
+        for c in self._make_loop(df.columns, verbose):
+            ax = axs[pos // n_cols, pos % n_cols]
+            (
                 df[c].plot.barh(title=f"{c}{title_suffix}", ax=ax)
-                ax.tick_params(axis="both", which="major", labelsize=8)
-                ax.grid(True)
-                pos += 1  # noqa: SIM113
-            fig.tight_layout()
-            imgdata = io.BytesIO()
-            fig.savefig(imgdata, format="png")
-            imgs.append(imgdata.getvalue())
-            plt.close()
-        else:
-            for c in loop:
-                fig, ax = plt.subplots(1, 1, figsize=(3, 3))
-                df[c].plot.barh(title=c, ax=ax)
-                ax.tick_params(axis="both", which="major", labelsize=8)
-                ax.grid(True)
-                fig.tight_layout()
-                imgdata = io.BytesIO()
-                fig.savefig(imgdata, format="png")
-                imgs.append(imgdata.getvalue())
-                plt.close()
+                if self.kind == "barh"
+                else df[c].plot.bar(title=f"{c}{title_suffix}", ax=ax)
+            )
+            ax.tick_params(axis="both", which="major", labelsize=8)
+            ax.grid(True)
+            pos += 1  # noqa: SIM113
+        fig.tight_layout()
+        imgdata = io.BytesIO()
+        fig.savefig(imgdata, format="png")
+        imgs.append(imgdata.getvalue())
+        plt.close()
         return imgs
 
-    def to_charts(self, writer: pandas.ExcelWriter, sheet, empty_row: int = 1):
-        """
-        Draws plots on a page.
-        The data is copied on this page.
-
-        :param name: sheet name
-        :param writer: writer (from pandas)
-        :param sheet_name: sheet
-        :param graph_index: graph index
-        :return: list of graph
-        """
-        assert self.split, f"Not implemented if split={self.split}"
-        assert self.orientation == "row", f"Not implemented if orientation={self.orientation}"
-        workbook = writer.book
-        labels = list(self.df.columns)
-        sheet.write_row(empty_row, 0, labels)
-
-        charts = []
-        pos = empty_row + 1
-        for i in self.df.index:
-            values = self.df.loc[i, :].tolist()
-            values = [("" if isinstance(v, float) and np.isnan(v) else v) for v in values]
-            sheet.write_row(pos, 0, values)
-            chart = workbook.add_chart({"type": "bar"})
-            chart.add_series(
-                {
-                    "name": i,
-                    "categories": [i, 1, empty_row, len(labels), empty_row],
-                    "values": [i, 1, pos, len(labels), pos],
-                }
-            )
-            chart.set_title({"name": i})
-            charts.append(chart)
-            pos += 1
-        return charts
+    def _to_images_line(
+        self, verbose: int = 0, merge: bool = True, title_suffix: Optional[str] = None
+    ) -> List[bytes]:
+        assert merge, f"merge={merge} not implemented yet"
+        assert (
+            self.orientation == "row"
+        ), f"self.orientation={self.orientation!r} not implemented for this kind of graph."
+
+        def rotate_align(ax, angle=15, align="right"):
+            for label in ax.get_xticklabels():
+                label.set_rotation(angle)
+                label.set_horizontalalignment(align)
+            ax.tick_params(axis="both", which="major", labelsize=8)
+            ax.grid(True)
+            ax.legend()
+            ax.tick_params(labelleft=True)
+            return ax
+
+        import matplotlib.pyplot as plt
+
+        df = self.df.T
+
+        confs = list(df.unstack(self.timeseries).index)
+        groups = self.group_columns(confs)
+        n_cols = len(groups)
+
+        title_suffix = f"\n{title_suffix}" if title_suffix else ""
+        fig, axs = plt.subplots(
+            df.shape[1],
+            n_cols,
+            figsize=(5 * n_cols, max(len(g) for g in groups) * df.shape[1] / 2),
+            sharex=True,
+            sharey="row" if n_cols > 1 else False,
+        )
+        imgs = []
+        row = 0
+        for c in self._make_loop(df.columns, verbose):
+            dfc = df[[c]]
+            dfc = dfc.unstack(self.timeseries).T.droplevel(0)
+            if n_cols == 1:
+                dfc.plot(title=f"{c}{title_suffix}", ax=axs[row], linewidth=3)
+                axs[row].grid(True)
+                rotate_align(axs[row])
+            else:
+                x = list(range(dfc.shape[0]))
+                ticks = list(dfc.index)
+                for ii, group in enumerate(groups):
+                    ddd = dfc.loc[:, group].copy()
+                    axs[row, ii].set_xticks(x)
+                    axs[row, ii].set_xticklabels(ticks)
+                    # This is very slow:
+                    # ddd.plot(ax=axs[row, ii], linewidth=3)
+                    for jj in range(ddd.shape[1]):
+                        axs[row, ii].plot(x, ddd.iloc[:, jj], lw=3, label=ddd.columns[jj])
+                    axs[row, ii].set_title(f"{c}{title_suffix}")
+                    rotate_align(axs[row, ii])
+            row += 1  # noqa: SIM113
+        fig.tight_layout()
+        imgdata = io.BytesIO()
+        fig.savefig(imgdata, format="png")
+        imgs.append(imgdata.getvalue())
+        plt.close()
+        return imgs
 
 
 class CubeLogs:
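Putting the new `timeseries` mode together, here is a hedged sketch of rendering line plots over a date level; the data and level names are invented, and the behavior is inferred from the source shown above.

```python
import numpy as np
import pandas

from onnx_diagnostic.helpers.log_helper import CubePlot

# two metrics observed on three dates; "date" is the time level of the columns
columns = pandas.MultiIndex.from_product(
    [["time_export", "speedup"], ["2024-01-01", "2024-01-02", "2024-01-03"]],
    names=["metric", "date"],
)
df = pandas.DataFrame(
    np.random.rand(2, 6), index=["model_a", "model_b"], columns=columns
)

# timeseries requires orientation="row"; kind must be one of CubePlot.KINDS
cube_plot = CubePlot(df, kind="line", orientation="row", split=True, timeseries="date")
images = cube_plot.to_images()  # list of PNG images as bytes
with open("timeseries.png", "wb") as f:
    f.write(images[0])
```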
@@ -1084,7 +1253,11 @@ class CubeLogs:
         df.to_excel(writer, sheet_name=main, freeze_panes=(1, 1))
 
         for name, view in views.items():
+            if view is None:
+                continue
             df, tview = self.view(view, return_view_def=True, verbose=max(verbose - 1, 0))
+            if tview is None:
+                continue
             memory = df.memory_usage(deep=True).sum()
             if verbose:
                 print(
@@ -1128,7 +1301,17 @@ class CubeLogs:
                 )
             f_highlights[name] = tview.f_highlight
             if tview.plots:
-                plots.append(CubePlot(df, kind="barh", orientation="row", split=True))
+                plots.append(
+                    CubePlot(
+                        df,
+                        kind="line",
+                        orientation="row",
+                        split=True,
+                        timeseries=self.time,
+                    )
+                    if self.time in df.columns.names
+                    else CubePlot(df, kind="barh", orientation="row", split=True)
+                )
         if raw:
             assert main not in views, f"{main!r} is duplicated in views {sorted(views)}"
             # Too long.
@@ -1249,9 +1432,11 @@ class CubeLogsPerformance(CubeLogs):
             "n_node_scatter",
             "n_node_function",
             "n_node_initializer",
+            "n_node_initializer_small",
             "n_node_constant",
             "n_node_shape",
             "n_node_expand",
+            "onnx_n_nodes_no_cst",
             "peak_gpu_torch",
             "peak_gpu_nvidia",
             "time_export_unbiased",
@@ -1419,6 +1604,9 @@ class CubeLogsPerformance(CubeLogs):
             n_node_function=lambda df: gpreserve(
                 df, "onnx_n_functions", gdf(df, "onnx_n_functions")
             ),
+            n_node_initializer_small=lambda df: gpreserve(
+                df, "op_onnx_initializer_small", gdf(df, "op_onnx_initializer_small")
+            ),
             n_node_initializer=lambda df: gpreserve(
                 df, "onnx_n_initializer", gdf(df, "onnx_n_initializer")
             ),
@@ -1437,6 +1625,10 @@ class CubeLogsPerformance(CubeLogs):
             ), f"Unexpected formula={formula!r}, should be in {sorted(lambdas)}"
             return lambdas[formula]
 
+        if formula == "onnx_n_nodes_no_cst":
+            return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(
+                df, "op_onnx__Constant", 0
+            ).fillna(0)
         if formula == "peak_gpu_torch":
             return lambda df: gdf(df, "mema_gpu_5_after_export") - gdf(df, "mema_gpu_4_reset")
         if formula == "peak_gpu_nvidia":
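The new `onnx_n_nodes_no_cst` formula subtracts the number of `Constant` nodes from the total node count, treating a missing `op_onnx__Constant` value as zero. In plain pandas (assuming `gdf` is a column getter with a default, which matches how it is used in these hunks):

```python
import pandas

df = pandas.DataFrame({"onnx_n_nodes": [120, 80], "op_onnx__Constant": [20, None]})
no_cst = df["onnx_n_nodes"] - df["op_onnx__Constant"].fillna(0)
# -> 100.0 and 80.0: node counts excluding Constant nodes
```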
@@ -1466,10 +1658,12 @@ class CubeLogsPerformance(CubeLogs):
 
     def view(
         self,
-        view_def: Union[str, CubeViewDef],
+        view_def: Optional[Union[str, CubeViewDef]],
         return_view_def: bool = False,
         verbose: int = 0,
-    ) -> Union[pandas.DataFrame, Tuple[pandas.DataFrame, CubeViewDef]]:
+    ) -> Union[
+        Optional[pandas.DataFrame], Tuple[Optional[pandas.DataFrame], Optional[CubeViewDef]]
+    ]:
         """
         Returns a dataframe, a pivot view.
 
@@ -1478,18 +1672,22 @@ class CubeLogsPerformance(CubeLogs):
         :param view_def: view definition or a string
         :param return_view_def: returns the view definition as well
         :param verbose: verbosity level
-        :return: dataframe
+        :return: dataframe or a couple (dataframe, view definition),
+            both of them can be None if view_def cannot be interpreted
         """
+        assert view_def is not None, "view_def is None, this is not allowed."
         if isinstance(view_def, str):
             view_def = self.make_view_def(view_def)
+            if view_def is None:
+                return (None, None) if return_view_def else None
         return super().view(view_def, return_view_def=return_view_def, verbose=verbose)
 
-    def make_view_def(self, name: str) -> CubeViewDef:
+    def make_view_def(self, name: str) -> Optional[CubeViewDef]:
         """
         Returns a view definition.
 
         :param name: name of the view
-        :return: a CubeViewDef
+        :return: a CubeViewDef or None if name does not make sense
 
         Available views:
 
@@ -1588,6 +1786,8 @@ class CubeLogsPerformance(CubeLogs):
                 "onnx_weight_size_torch",
                 "onnx_weight_size_proto",
                 "onnx_n_nodes",
+                "onnx_n_nodes_no_cst",
+                "op_onnx__Constant",
                 "peak_gpu_torch",
                 "peak_gpu_nvidia",
             ],
@@ -1617,6 +1817,7 @@ class CubeLogsPerformance(CubeLogs):
                 "onnx_weight_size_torch",
                 "onnx_weight_size_proto",
                 "onnx_n_nodes",
+                "onnx_n_nodes_no_cst",
                 "peak_gpu_torch",
                 "peak_gpu_nvidia",
             ],
@@ -1701,12 +1902,22 @@ class CubeLogsPerformance(CubeLogs):
                 f_highlight=f_bucket,
                 order=order,
             ),
-            "cmd": lambda: CubeViewDef(
+            "onnx": lambda: CubeViewDef(
                 key_index=index_cols,
-                values=self._filter_column(["CMD"], self.values),
+                values=self._filter_column(
+                    [
+                        "onnx_filesize",
+                        "onnx_n_nodes",
+                        "onnx_n_nodes_no_cst",
+                        "onnx_weight_size_proto",
+                        "onnx_weight_size_torch",
+                        "op_onnx_initializer_small",
+                    ],
+                    self.values,
+                ),
                 ignore_unique=True,
                 keep_columns_in_index=["suite"],
-                name="cmd",
+                name="onnx",
                 order=order,
             ),
             "raw-short": lambda: CubeViewDef(
@@ -1718,11 +1929,25 @@ class CubeLogsPerformance(CubeLogs):
                 no_index=True,
             ),
         }
-        assert name in implemented_views, (
+
+        cmd_col = self._filter_column(["CMD"], self.values, can_be_empty=True)
+        if cmd_col:
+            implemented_views["cmd"] = lambda: CubeViewDef(
+                key_index=index_cols,
+                values=cmd_col,
+                ignore_unique=True,
+                keep_columns_in_index=["suite"],
+                name="cmd",
+                order=order,
+            )
+
+        assert name in implemented_views or name in {"cmd"}, (
             f"Unknown view {name!r}, expected a name in {sorted(implemented_views)},"
             f"\n--\nkeys={pprint.pformat(sorted(self.keys_time))}, "
             f"\n--\nvalues={pprint.pformat(sorted(self.values))}"
         )
+        if name not in implemented_views:
+            return None
         return implemented_views[name]()
 
     def post_load_process_piece(
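With this change the `cmd` view is only registered when a `CMD` column exists, and `make_view_def` returns `None` for a view it cannot build instead of failing. A hedged sketch of the calling pattern (`cube` stands for an already-loaded `CubeLogsPerformance`):

```python
view_def = cube.make_view_def("cmd")  # None when no CMD column was logged
if view_def is not None:
    df_view = cube.view(view_def)
# the export loop shown in an earlier hunk relies on the same convention:
# views whose definition comes back as None are silently skipped
```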
onnx_diagnostic/helpers/rt_helper.py

@@ -112,4 +112,14 @@ def make_feeds(
 
     if copy:
         flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
-    return dict(zip(names, flat))
+    # onnxruntime does not accept Python bool, int or float values:
+    # wrap scalars as numpy arrays before building the feeds
+    new_flat = []
+    for i in flat:
+        if isinstance(i, bool):
+            i = np.array(i, dtype=np.bool_)
+        elif isinstance(i, int):
+            i = np.array(i, dtype=np.int64)
+        elif isinstance(i, float):
+            i = np.array(i, dtype=np.float32)
+        new_flat.append(i)
+    return dict(zip(names, new_flat))
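The ordering of the `isinstance` checks above matters because `bool` is a subclass of `int` in Python; testing `int` first would turn `True` into an int64 array. A standalone sketch of the same casting rule:

```python
import numpy as np

def cast_scalar(v):
    # bool must be tested before int: isinstance(True, int) is True
    if isinstance(v, bool):
        return np.array(v, dtype=np.bool_)
    if isinstance(v, int):
        return np.array(v, dtype=np.int64)
    if isinstance(v, float):
        return np.array(v, dtype=np.float32)
    return v

assert cast_scalar(True).dtype == np.bool_
assert cast_scalar(3).dtype == np.int64
assert cast_scalar(0.5).dtype == np.float32
```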
onnx_diagnostic/helpers/torch_helper.py

@@ -717,7 +717,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return tuple(to_any(t, to_value) for t in value)
     if isinstance(value, set):
         return {to_any(t, to_value) for t in value}
-    if isinstance(value, dict):
+    if type(value) is dict:
         return {k: to_any(t, to_value) for k, t in value.items()}
     if value.__class__.__name__ == "DynamicCache":
         return make_dynamic_cache(
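Switching from `isinstance(value, dict)` to `type(value) is dict` means dict *subclasses* no longer take the generic branch; they fall through to the class-name checks below, which is what cache-like mappings need. A small demonstration:

```python
from collections import OrderedDict

d = OrderedDict(a=1)
print(isinstance(d, dict))  # True  -> old test also caught subclasses
print(type(d) is dict)      # False -> new test lets subclasses reach
                            #          the specialized branches below
```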
@@ -735,7 +735,8 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
                     [t.to(to_value) for t in value.key_cache],
                     [t.to(to_value) for t in value.value_cache],
                 )
-            )
+            ),
+            max_cache_len=value.max_cache_len,
         )
     if value.__class__.__name__ == "EncoderDecoderCache":
         return make_encoder_decoder_cache(
@@ -784,7 +785,10 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
         )
     if value.__class__.__name__ == "StaticCache":
-        return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
+        return make_static_cache(
+            torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
+            max_cache_len=value.max_cache_len,
+        )
     if value.__class__.__name__ == "SlidingWindowCache":
         return make_sliding_window_cache(
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
onnx_diagnostic/tasks/__init__.py

@@ -11,6 +11,7 @@ from . import (
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 )
@@ -27,6 +28,7 @@ __TASKS__ = [
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 ]
onnx_diagnostic/tasks/text_generation.py

@@ -109,7 +109,7 @@ def get_inputs(
         sequence_length2 = seq_length_multiple
 
     shapes = {
-        "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "input_ids": {0: batch, 1: "sequence_length"},
         "attention_mask": {
             0: batch,
             1: "cache+seq",  # cache_length + seq_length
@@ -176,8 +176,10 @@ def get_inputs(
         "attention_mask": {0: batch, 2: "seq"},
         "cache_position": {0: "seq"},
         "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
         ],
     }
     inputs = dict(
@@ -188,18 +190,25 @@ def get_inputs(
                 (batch_size, num_key_value_heads, sequence_length2, head_dim)
             ).to(torch.bool),
             cache_position=torch.arange(sequence_length2).to(torch.int64),
-            past_key_values=make_cache(
+            past_key_values=make_static_cache(
                 [
                     (
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
-                ]
+                ],
+                max_cache_len=max(sequence_length + sequence_length2, head_dim),
             ),
         )
@@ -230,7 +239,7 @@ def get_inputs(
             position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
             .to(torch.int64)
             .expand((batch_size, -1)),
-            past_key_values=make_cache(
+            past_key_values=make_cache(  # type: ignore[operator]
                 [
                     (
                         torch.randn(