onnx-diagnostic 0.8.10__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
  4. onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
  5. onnx_diagnostic/ci_models/data/__init__.py +0 -0
  6. onnx_diagnostic/ci_models/export_phi4_mm.py +10 -7
  7. onnx_diagnostic/export/api.py +13 -4
  8. onnx_diagnostic/export/dynamic_shapes.py +1 -1
  9. onnx_diagnostic/export/validate.py +2 -0
  10. onnx_diagnostic/ext_test_case.py +32 -15
  11. onnx_diagnostic/helpers/args_helper.py +1 -0
  12. onnx_diagnostic/helpers/bench_run.py +0 -1
  13. onnx_diagnostic/helpers/cache_helper.py +102 -36
  14. onnx_diagnostic/helpers/doc_helper.py +7 -4
  15. onnx_diagnostic/helpers/graph_helper.py +6 -6
  16. onnx_diagnostic/helpers/helper.py +39 -0
  17. onnx_diagnostic/helpers/log_helper.py +37 -14
  18. onnx_diagnostic/helpers/memory_peak.py +5 -1
  19. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  20. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  21. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  22. onnx_diagnostic/helpers/ort_session.py +5 -2
  23. onnx_diagnostic/helpers/rt_helper.py +53 -9
  24. onnx_diagnostic/helpers/torch_helper.py +15 -11
  25. onnx_diagnostic/investigate/__init__.py +0 -0
  26. onnx_diagnostic/investigate/input_observer.py +970 -0
  27. onnx_diagnostic/reference/evaluator.py +0 -1
  28. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  29. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  30. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  31. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  32. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  33. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  34. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  35. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
  36. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  39. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
  40. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
  41. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +70 -23
  42. onnx_diagnostic/torch_models/code_sample.py +5 -10
  43. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  44. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
  45. onnx_diagnostic/torch_models/validate.py +1 -1
  46. onnx_diagnostic/torch_onnx/compare.py +0 -1
  47. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  48. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  49. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  50. onnx_diagnostic/typing.py +15 -0
  51. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
  52. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +55 -50
  53. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
  54. onnx_diagnostic/api.py +0 -15
  55. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  56. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/log_helper.py
@@ -29,10 +29,10 @@ class CubeViewDef:
     :param order: to reorder key in columns index
     :param key_agg: aggregate according to these columns before
         creating the view
-    :param agg_args: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`,
+    :param agg_args: see :meth:`pandas.api.typing.DataFrameGroupBy.agg`,
         it can be also a callable to return a different aggregation
         method depending on the column name
-    :param agg_kwargs: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`
+    :param agg_kwargs: see :meth:`pandas.api.typing.DataFrameGroupBy.agg`
     :param agg_multi: aggregation over multiple columns
     :param ignore_columns: ignore the following columns if known to overload the view
     :param keep_columns_in_index: keeps the columns even if there is only one unique value
@@ -98,7 +98,7 @@ class CubeViewDef:
         agg_args: Union[Sequence[Any], Callable[[str], Any]] = ("sum",),
         agg_kwargs: Optional[Dict[str, Any]] = None,
         agg_multi: Optional[
-            Dict[str, Callable[[pandas.core.groupby.DataFrameGroupBy], pandas.Series]]
+            Dict[str, Callable[[pandas.api.typing.DataFrameGroupBy], pandas.Series]]
         ] = None,
         ignore_columns: Optional[Sequence[str]] = None,
         keep_columns_in_index: Optional[Sequence[str]] = None,
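
The two hunks above swap ``pandas.core.groupby.DataFrameGroupBy`` for its public alias: since pandas 2.0, ``pandas.api.typing`` re-exports the groupby result types for annotations, while ``pandas.core`` is private and free to move between releases. A minimal sketch of the new spelling (the data and the aggregation are illustrative)::

    import pandas
    from pandas.api.typing import DataFrameGroupBy  # public alias since pandas 2.0

    def total(gr: DataFrameGroupBy) -> pandas.DataFrame:
        # Same runtime object as pandas.core.groupby.DataFrameGroupBy,
        # but imported through the supported public path.
        return gr.agg("sum")

    df = pandas.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})
    print(total(df.groupby("k")))  # v: a -> 3, b -> 3
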
@@ -365,6 +365,7 @@ class CubePlot:
                 # This is very slow
                 # ddd.plot(ax=axs[row, ii],linewidth=3)
                 for jj in range(ddd.shape[1]):
+                    # pyrefly: ignore[bad-index]
                     axs[row, ii].plot(x, ddd.iloc[:, jj], lw=3, label=ddd.columns[jj])
                 axs[row, ii].set_title(f"{c}{title_suffix}")
                 rotate_align(axs[row, ii])
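
Most additions in this release are ``# pyrefly: ignore[...]`` comments, which, as used throughout this diff, suppress one diagnostic of the named error code on the line that follows. The idiom outside the library (the array and assignment are invented)::

    import numpy as np

    grid = np.empty((2, 2), dtype=object)  # e.g. the axes array from plt.subplots
    # pyrefly: ignore[bad-index]
    grid[0, 0] = "axes-like object"  # checker cannot prove this tuple index is valid
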
@@ -480,7 +481,9 @@ class CubeLogs:
         elif isinstance(self._data, list) and all(isinstance(r, dict) for r in self._data):
             if verbose:
                 print(f"[CubeLogs.load] load from list of dicts, n={len(self._data)}")
-            self.data = pandas.DataFrame(self.post_load_process_piece(self._data, unique=True))
+            self.data = pandas.DataFrame(
+                self.post_load_process_piece(pandas.DataFrame(self._data), unique=True)
+            )
             if verbose:
                 print(f"[CubeLogs.load] after postprocessing shape={self.data.shape}")
         elif isinstance(self._data, list) and all(
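
The fix builds the DataFrame before post-processing, so ``post_load_process_piece`` always receives a frame rather than a raw list of dicts. The conversion itself is plain pandas (rows and columns here are invented)::

    import pandas

    rows = [{"model": "m1", "time_latency": 1.2}, {"model": "m2", "time_latency": 0.8}]
    frame = pandas.DataFrame(rows)  # list of dicts -> one row per dict
    assert frame.shape == (2, 2)
    # pandas.DataFrame(frame) is a cheap re-wrap, so the outer constructor in
    # the patched code still works when the method hands back a DataFrame.
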
@@ -614,7 +617,7 @@ class CubeLogs:

     def _process_formula(
         self, formula: Union[str, Callable[[pandas.DataFrame], pandas.Series]]
-    ) -> Callable[[pandas.DataFrame], pandas.Series]:
+    ) -> Callable[[pandas.DataFrame], Optional[pandas.Series]]:
         assert callable(formula), f"formula={formula!r} is not supported."
         return formula

@@ -625,9 +628,11 @@ class CubeLogs:
         return self.data.shape

     @property
-    def columns(self) -> Sequence[str]:
+    def columns(self) -> Sequence[Any]:
         "Returns the columns."
         assert hasattr(self, "data"), "Method load was not called"
+        assert isinstance(self.data, pandas.DataFrame)  # type checking
+        # pyrefly: ignore[bad-return]
         return self.data.columns

     def _preprocess(self):
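
The inserted ``assert isinstance(...)`` lines cost one cheap runtime check but let the checker narrow an otherwise loose attribute type. The pattern in isolation (function and argument type are invented)::

    import pandas
    from typing import Any

    def shape_of(data: Any) -> tuple:
        assert isinstance(data, pandas.DataFrame)  # type checking
        # After the assert, checkers treat `data` as a DataFrame, so .shape resolves.
        return data.shape
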
@@ -647,7 +652,7 @@ class CubeLogs:
         )
         assert gr.shape[0] > 0, (
             f"Something went wrong after the groupby.\n"
-            f"{cp[[*self.keys, self.time, '__index__']].head().T}"
+            f"{cp[[*self.keys_no_time, self.time, '__index__']].head().T}"
         )
         filtered = pandas.merge(cp, gr, on=["__index__", *self.keys_time])
         assert filtered.shape[0] <= self.data.shape[0], (
@@ -797,6 +802,7 @@ class CubeLogs:
         if view_def.agg_multi:
             append = []
             for k, f in view_def.agg_multi.items():
+                # pyrefly: ignore[no-matching-overload]
                 cv = grouped_data.apply(f, include_groups=False)
                 append.append(cv.to_frame(k))
             data = pandas.concat([data, *append], axis=1)
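
``include_groups=False`` (pandas >= 2.2) keeps the grouping keys out of the frame handed to the callable, which is what lets each ``agg_multi`` function see only metric columns. A self-contained sketch with invented data::

    import pandas

    df = pandas.DataFrame({"k": ["a", "a", "b"], "lo": [1, 2, 3], "hi": [4, 5, 6]})
    grouped = df.groupby("k")
    # The callable receives the lo/hi columns only, never the key column k.
    cv = grouped.apply(lambda g: g["hi"].sum() / g["lo"].sum(), include_groups=False)
    print(cv.to_frame("hi_over_lo"))  # a -> 3.0, b -> 2.0
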
@@ -1020,8 +1026,10 @@ class CubeLogs:

         keys = set(self.keys_no_time) - {columns_to_fix}
         select = data[self.keys_no_time]
+        # pyrefly: ignore[no-matching-overload]
         select_agg = select.groupby(list(keys), as_index=True).apply(
-            lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))), include_groups=False
+            lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))),
+            include_groups=False,
         )
         select_agg = select_agg.to_frame(name=columns_to_fix)
         res = pandas.merge(
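
The lambda collapses every distinct value a key combination has seen into one dash-joined string, which is then merged back onto the data. The same pattern on toy columns (names invented)::

    import pandas

    data = pandas.DataFrame(
        {"machine": ["m1", "m1", "m2"], "version": ["1.17", "1.18", "1.18"]}
    )
    agg = data.groupby("machine", as_index=True).apply(
        lambda x: "-".join(sorted(set(x["version"].dropna()))),
        include_groups=False,
    )
    print(agg.to_frame(name="version"))  # m1 -> "1.17-1.18", m2 -> "1.18"
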
@@ -1137,6 +1145,7 @@ class CubeLogs:
             if len(nonan) > 0:
                 obs.update(dict(count=len(nonan)))
                 if is_numeric_dtype(nonan) and not pandas.api.types.is_object_dtype(nonan):
+                    # pyrefly: ignore[no-matching-overload]
                     obs.update(
                         dict(
                             min=nonan.min(),
@@ -1208,12 +1217,15 @@ class CubeLogs:
             df.to_excel(writer, sheet_name=main, freeze_panes=(1, 1))

         time_mask_view: Dict[str, pandas.DataFrame] = {}
+        df = None
         for name, view in views.items():
             if view is None:
                 continue
             df, tview = self.view(view, return_view_def=True, verbose=max(verbose - 1, 0))
             if cube_time is not None:
                 cube_mask = cube_time.view(view)
+                assert isinstance(cube_mask, pandas.DataFrame)  # type checking
+                assert isinstance(df, pandas.DataFrame)  # type checking
                 aligned = align_dataframe_with(cube_mask, df)
                 if aligned is not None:
                     assert aligned.shape == df.shape, (
@@ -1228,6 +1240,7 @@ class CubeLogs:
             )
             if tview is None:
                 continue
+            assert isinstance(df, pandas.DataFrame)  # type checking
             memory = df.memory_usage(deep=True).sum()
             if verbose:
                 print(
@@ -1269,7 +1282,9 @@ class CubeLogs:
                 sheet_name=name,
                 freeze_panes=(df.columns.nlevels + 1, df.index.nlevels),
             )
+            # pyrefly: ignore[missing-attribute]
             f_highlights[name] = tview.f_highlight
+            # pyrefly: ignore[missing-attribute]
             if tview.plots:
                 plots.append(
                     CubePlot(
@@ -1282,6 +1297,7 @@ class CubeLogs:
                     if self.time in df.columns.names
                     else CubePlot(df, kind="barh", orientation="row", split=True)
                 )
+        assert isinstance(df, pandas.DataFrame)  # type checking
         if raw:
             assert main not in views, f"{main!r} is duplicated in views {sorted(views)}"
             # Too long.
@@ -1439,7 +1455,7 @@ class CubeLogs:
             len(configs) >= 2
         ), f"A side by side needs at least two configs but configs={configs}"
         set_keys_time = set(self.keys_time)
-        columns_index = None
+        columns_index: Optional[List[str]] = None
         data_list = []
         for name_conf, conf in configs.items():
             if columns_index is None:
@@ -1478,9 +1494,11 @@ class CubeLogs:

         # add metrics
         index_column_name = list(view_res.columns.names).index(column_name)
+        # pyrefly: ignore[missing-attribute]
         index_metrics = list(view_res.columns.names).index("METRICS")

         def _mkc(m, s):
+            # pyrefly: ignore[missing-attribute]
             c = ["" for c in view_res.columns.names]
             c[index_column_name] = s
             c[index_metrics] = m
@@ -1515,7 +1533,9 @@ class CubeLogs:
             ci["CONF"] = iname
             cj["CONF"] = jname

+            # pyrefly: ignore[bad-index]
             ci_name = tuple(ci[n] for n in view_res.columns.names)
+            # pyrefly: ignore[bad-index]
             cj_name = tuple(cj[n] for n in view_res.columns.names)
             assert ci_name in view_res.columns or cj_name in view_res.columns, (
                 f"Unable to find column {ci_name} or {cj_name} "
@@ -1562,6 +1582,7 @@ class CubeLogs:
         }
         flat = view_res.groupby(self.time).agg(aggs)
         flat = flat.stack("METRICS", future_stack=True)
+        # pyrefly: ignore[bad-return, missing-attribute]
         return res, flat, view_res.T.sort_index().T

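
``future_stack=True`` opts into the pandas >= 2.1 stacking implementation (no implicit ``dropna``), which is what this call relies on. How stacking a named column level behaves, on an invented two-level frame::

    import pandas

    cols = pandas.MultiIndex.from_product(
        [["eager", "export"], ["mean", "max"]], names=["CONF", "METRICS"]
    )
    df = pandas.DataFrame([[1.0, 2.0, 3.0, 4.0]], columns=cols)
    flat = df.stack("METRICS", future_stack=True)
    # The METRICS level moves from the columns into the row index:
    print(flat)  # rows (0, "mean"), (0, "max"); columns "eager", "export"
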
@@ -1679,7 +1700,7 @@ class CubeLogsPerformance(CubeLogs):

     def _process_formula(
         self, formula: Union[str, Callable[[pandas.DataFrame], pandas.Series]]
-    ) -> Callable[[pandas.DataFrame], pandas.Series]:
+    ) -> Callable[[pandas.DataFrame], Optional[pandas.Series]]:
         """
         Processes a formula, converting it into a function.

@@ -1726,6 +1747,7 @@ class CubeLogsPerformance(CubeLogs):
                 f"{pprint.pformat(sorted(columns))}"
             )
             # return lambda df: df["time_latency_eager"] / df["time_latency"]
+            # pyrefly: ignore[no-matching-overload]
             return lambda df: pandas.cut(
                 df["speedup"], bins=BUCKET_SCALES, right=False, duplicates="raise"
             )
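
``pandas.cut`` with ``right=False`` produces left-closed buckets ``[a, b)``, so a speedup of exactly 1.0 falls into the bucket that starts at 1.0. A sketch with made-up edges standing in for ``BUCKET_SCALES``::

    import pandas

    BUCKET_SCALES = [0.0, 0.9, 1.0, 1.1, 2.0, 10.0]  # illustrative edges only
    speedup = pandas.Series([0.95, 1.0, 1.5, 3.1], name="speedup")
    buckets = pandas.cut(speedup, bins=BUCKET_SCALES, right=False, duplicates="raise")
    print(buckets)  # 1.0 lands in [1.0, 1.1), not in [0.9, 1.0)
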
@@ -1733,9 +1755,9 @@ class CubeLogsPerformance(CubeLogs):
         if formula == "ERR1":
             columns = set(self._filter_column(["^ERR_.*"], self.data.columns))
             if not columns:
-                return lambda df: np.nan
+                return lambda df: None

-            def first_err(df: pandas.DataFrame) -> pandas.Series:
+            def first_err(df: pandas.DataFrame) -> Optional[pandas.Series]:
                 ordered = [
                     c
                     for c in [
@@ -1752,7 +1774,7 @@ class CubeLogsPerformance(CubeLogs):
                     ]
                     if c in df.columns
                 ]
-                res = None
+                res: Optional[pandas.Series] = None
                 for c in ordered:
                     if res is None:
                         res = df[c].fillna("")
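
The loop keeps, per row, the first non-empty message across the error columns in priority order. A minimal sketch of that pattern; the column names and the combination step (``Series.where``) are illustrative, only the first assignment is taken from the hunk::

    import pandas
    from typing import Optional

    def first_err(df: pandas.DataFrame) -> Optional[pandas.Series]:
        ordered = [c for c in ["ERR_export", "ERR_run"] if c in df.columns]  # invented names
        res: Optional[pandas.Series] = None
        for c in ordered:
            if res is None:
                res = df[c].fillna("")
            else:
                # keep the earlier message, fall back to this column where empty
                res = res.where(res != "", df[c].fillna(""))
        return res

    df = pandas.DataFrame({"ERR_export": [None, "boom"], "ERR_run": ["late", "later"]})
    print(first_err(df))  # row 0 -> "late", row 1 -> "boom"
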
@@ -1949,6 +1971,7 @@ class CubeLogsPerformance(CubeLogs):
             f"{pprint.pformat(sorted(self.data.columns))}"
         )

+    # pyrefly: ignore[bad-override]
     def view(
         self,
         view_def: Optional[Union[str, CubeViewDef]],
@@ -2265,7 +2288,7 @@ class CubeLogsPerformance(CubeLogs):
        if unique:
            return df
        cols = self._filter_column(self._keys, df)
-       res = None
+       res: Optional[pandas.DataFrame] = None
        for c in cols:
            if df[c].isna().any():
                # Missing values for keys are not supposed to happen.
onnx_diagnostic/helpers/memory_peak.py
@@ -103,6 +103,7 @@ def _process_memory_spy(conn):
     process = psutil.Process(pid)

     if cuda:
+        # pyrefly: ignore[missing-import]
         from pynvml import (
             nvmlDeviceGetCount,
             nvmlDeviceGetHandleByIndex,
@@ -131,6 +132,7 @@ def _process_memory_spy(conn):
         mem = process.memory_info().rss
         cpu.update(mem)
         if cuda:
+            # pyrefly: ignore[unbound-name]
             for r, g in zip(gpu_used(), gpus):
                 g.update(r)
         if conn.poll(timeout=timeout):
@@ -142,6 +144,7 @@ def _process_memory_spy(conn):
     end = process.memory_info().rss
     cpu.update(end)
     if cuda:
+        # pyrefly: ignore[unbound-name]
         for r, g in zip(gpu_used(), gpus):
             g.update(r)

@@ -151,6 +154,7 @@ def _process_memory_spy(conn):
     for g in gpus:
         g.send(conn)
     if cuda:
+        # pyrefly: ignore[unbound-name]
        nvmlShutdown()
     conn.close()

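
The ``unbound-name`` ignores exist because ``nvmlShutdown`` and friends are only imported inside the ``if cuda:`` branch, so the checker cannot see them bound on every path. For reference, the standard NVML pattern those symbols implement, sketched with the extra ``nvmlInit``/``nvmlDeviceGetMemoryInfo`` calls a complete query needs (not shown in this diff)::

    # pyrefly: ignore[missing-import]
    from pynvml import (
        nvmlInit,
        nvmlShutdown,
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
    )

    def gpu_used():
        "Used memory in bytes for every visible CUDA device."
        return [
            nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)).used
            for i in range(nvmlDeviceGetCount())
        ]

    nvmlInit()
    try:
        print(gpu_used())
    finally:
        nvmlShutdown()
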
@@ -217,7 +221,7 @@ def start_spying_on(
     Starts the memory spy. The function starts another
     process spying on the one sent as an argument.

-    :param pid: process id to spy or the the current one.
+    :param pid: process id to spy or the current one.
     :param delay: delay between two measures.
     :param cuda: True or False to get memory for cuda devices

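
A hypothetical usage of the documented parameters; the call below assumes only what the docstring states, and the returned monitor with its ``stop()`` method is an assumption, not shown in this diff::

    from onnx_diagnostic.helpers.memory_peak import start_spying_on

    # Assumed API: the helper returns a handle whose stop() yields the measures.
    spy = start_spying_on(delay=0.1, cuda=False)  # spy on the current process
    # ... run the code whose memory peak should be measured ...
    measures = spy.stop()  # hypothetical; check the package docs for the real call
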
onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -8,11 +8,6 @@ import torch
 from .onnx_helper import dtype_to_tensor_dtype, tensor_dtype_to_np_dtype, from_array_extended
 from . import string_type

-STORAGE_TYPE = {
-    TensorProto.FLOAT16: np.int16,
-    TensorProto.BFLOAT16: np.int16,
-}
-

 def proto_from_array(
     arr: torch.Tensor,
@@ -67,13 +62,13 @@ def proto_from_array(
         byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
         tensor.raw_data = bytes(byte_data)
         if sys.byteorder == "big":
-            np_dtype = tensor_dtype_to_np_dtype(STORAGE_TYPE[tensor.data_type])
-            np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
+            np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
+            np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
     else:
         tensor.raw_data = np_arr.tobytes()
         if sys.byteorder == "big":
             np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
-            np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
+            np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)

     return tensor

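
This is a real bug fix, not a style change: numpy has no module-level ``np.byteswap``; byte swapping is an ``ndarray`` method, and the removed ``STORAGE_TYPE`` table (which only covered FLOAT16/BFLOAT16) is unnecessary once the tensor's own dtype is used. The corrected idiom in isolation, on a writable buffer (values invented)::

    import numpy as np

    raw = np.array([1.0, 2.0], dtype=np.float32).tobytes()
    view = np.frombuffer(bytearray(raw), dtype=np.float32)  # bytearray => writable view
    view.byteswap(inplace=True)  # ndarray method; swaps the buffer in place
    # np.byteswap(view, inplace=True) would raise AttributeError:
    # numpy exposes no such module-level function.
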
@@ -133,6 +128,7 @@ class MiniOnnxBuilder:
         }
         shape = tuple(map(int, tensor.shape))
         self.nodes.append(
+            # pyrefly: ignore[bad-argument-type]
             oh.make_node(op_type, [], [name], dtype=dtype, shape=shape, **kwargs)
         )
         self.outputs.append(oh.make_tensor_value_info(name, dtype, shape))
@@ -632,6 +628,7 @@ def create_input_tensors_from_onnx_model(
         raise AssertionError(f"Unexpected value for engine={engine!r}")

     got = sess.run(None, {})
+    assert isinstance(got, list)  # type checking
     if len(names) == 1:
         name = names[0]
         output = got[0]
@@ -639,12 +636,10 @@ def create_input_tensors_from_onnx_model(
            return None
        if name == "array":
            return output
-       if name == "bool":
-           return bool(output[0])
-       if name == "int":
-           return int(output[0])
-       if name == "float":
-           return float(output[0])
+       if name in {"bool", "int", "float"}:
+           cvt = {"bool": bool, "int": int, "float": float}[name]
+           # pyrefly: ignore[bad-index]
+           return cvt(output[0])
        if name == "tensor":
            return torch.from_numpy(output).to(device)
        assert name.startswith(
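
Three identical branches collapse into one dispatch table; the idiom in isolation, where ``output`` stands in for the one-element result returned by the session::

    cvt = {"bool": bool, "int": int, "float": float}
    output = [3.14]  # e.g. a one-element array returned by sess.run
    value = cvt["float"](output[0])
    assert value == 3.14 and type(value) is float
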
onnx_diagnostic/helpers/model_builder_helper.py
@@ -14,7 +14,7 @@ CACHE_SUBDIR = "onnx-diagnostic"

 def download_model_builder_to_cache(
     url: str = "https://raw.githubusercontent.com/microsoft/onnxruntime-genai/refs/heads/main/src/python/py/models/builder.py",
-):
+) -> Path:
     """
     Downloads ``builder.py`` from the
     ``https://github.com/microsoft/onnxruntime-genai/blob/main/src/python/py/models/builder.py``.