onnx-diagnostic 0.8.9__py3-none-any.whl → 0.8.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
  4. onnx_diagnostic/export/api.py +24 -12
  5. onnx_diagnostic/export/validate.py +2 -0
  6. onnx_diagnostic/ext_test_case.py +32 -15
  7. onnx_diagnostic/helpers/args_helper.py +1 -0
  8. onnx_diagnostic/helpers/bench_run.py +0 -1
  9. onnx_diagnostic/helpers/cache_helper.py +6 -6
  10. onnx_diagnostic/helpers/doc_helper.py +7 -4
  11. onnx_diagnostic/helpers/graph_helper.py +6 -6
  12. onnx_diagnostic/helpers/log_helper.py +37 -14
  13. onnx_diagnostic/helpers/memory_peak.py +5 -1
  14. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  15. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  16. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  17. onnx_diagnostic/helpers/ort_session.py +0 -1
  18. onnx_diagnostic/helpers/torch_helper.py +8 -9
  19. onnx_diagnostic/investigate/__init__.py +0 -0
  20. onnx_diagnostic/investigate/input_observer.py +329 -0
  21. onnx_diagnostic/reference/evaluator.py +0 -1
  22. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  23. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  24. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  25. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  26. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  27. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  28. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  29. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +5 -1
  30. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  32. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +14 -13
  33. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
  34. onnx_diagnostic/torch_models/code_sample.py +5 -10
  35. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +7 -12
  37. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -0
  38. onnx_diagnostic/torch_models/validate.py +1 -1
  39. onnx_diagnostic/torch_onnx/compare.py +0 -1
  40. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  41. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  42. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  43. onnx_diagnostic/typing.py +15 -0
  44. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
  45. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +48 -46
  46. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
  47. onnx_diagnostic/api.py +0 -15
  48. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
  49. {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
@@ -445,10 +445,6 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
  and not isinstance(v, (bool, int, float))
  }
  )
- if self.expand_batch_for:
- # extends the inputs to artificially create a batch dimension != 1.
- inp_args = self._expand_batch_dimension(inp_args, self.expand_batch_for)
- inp_kwargs = self._expand_batch_dimension(inp_kwargs, self.expand_batch_for)
  inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
  # reorders the parameter following the method signature.
  inp_kwargs = self._reorder_kwargs(inp_kwargs)
@@ -513,12 +509,10 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
  simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
  args = str(simple_sig)[1:-1]
  calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
- src = textwrap.dedent(
- f"""
+ src = textwrap.dedent(f"""
  def f(self, {args}):
  return self._method_call({calls_args})
- """
- )
+ """)
  self._method_src = src
  ns = {}
  try:
@@ -557,6 +551,10 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
  else:
  a, kw = self._inputs[-1]
  nds = [self.dynamic_shapes]
+ if self.expand_batch_for:
+ # extends the inputs to artificially create a batch dimension != 1.
+ a = self._expand_batch_dimension(a, self.expand_batch_for)
+ kw = self._expand_batch_dimension(kw, self.expand_batch_for)
  if self.verbose:
  print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
  print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
@@ -738,7 +736,9 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
  :param verbose: verbosity
  :return: results, a list of dictionaries, ready to be consumed by a dataframe
  """
- assert self._export_done, "The onnx export was not done."
+ assert (
+ self._export_done
+ ), f"The onnx export was not done, only {len(self._inputs)} were stored."
  assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
  assert os.path.exists(
  self._output_file
@@ -750,17 +750,29 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
  classes = [
  cls
  for cls in self._serialization_classes
- if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
+ if cls
+ not in {
+ int,
+ float,
+ bool,
+ str,
+ torch.Tensor,
+ list,
+ set,
+ dict,
+ torch.device,
+ torch.dtype,
+ }
  ]
  if verbose:
  print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
  print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
  with torch.serialization.safe_globals(classes):
- inputs = torch.load(self._input_file)
+ inputs = torch.load(self._input_file, weights_only=False)
  if verbose:
  print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
  with torch.serialization.safe_globals(classes):
- outputs = torch.load(self._output_file)
+ outputs = torch.load(self._output_file, weights_only=False)
  assert len(inputs) == len(outputs), (
  f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
  f"inputs={string_type(inputs, with_shape=True)}, "
@@ -80,6 +80,7 @@ def compare_modules(
  )
  got = modep(*_get(args), **_get(kwargs))
  if verbose:
+ # pyrefly: ignore[unbound-name]
  d = time.perf_counter() - begin
  print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
  if mod:
@@ -89,6 +90,7 @@ def compare_modules(
  expected = mod(*_get(args), **_get(kwargs))
  diff = max_diff(expected, got)
  if verbose:
+ # pyrefly: ignore[unbound-name]
  d = time.perf_counter() - begin
  print(
  f"[compare_modules] done in {d} with "
@@ -780,7 +780,7 @@ class ExtTestCase(unittest.TestCase):

  @property
  def verbose(self) -> int:
- "Returns the the value of environment variable ``VERBOSE``."
+ "Returns the value of environment variable ``VERBOSE``."
  return int(os.environ.get("VERBOSE", "0"))

  @classmethod
@@ -1028,6 +1028,19 @@
  rtol=rtol,
  msg=msg,
  )
+ elif expected.__class__.__name__ == "BaseModelOutputWithPooling":
+ if expected.__class__.__name__ == value.__class__.__name__:
+ self.assertEqual(len(expected), len(value), msg=msg)
+ self.assertEqual(list(expected), list(value), msg=msg) # checks the order
+ self.assertEqualAny(
+ {k: v for k, v in expected.items()}, # noqa: C416
+ {k: v for k, v in value.items()}, # noqa: C416
+ atol=atol,
+ rtol=rtol,
+ msg=msg,
+ )
+ else:
+ self.assertEqualArray(expected.last_hidden_state, value)
  elif isinstance(expected, (tuple, list, dict)):
  self.assertIsInstance(value, type(expected), msg=msg)
  self.assertEqual(len(expected), len(value), msg=msg)
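`BaseModelOutputWithPooling` is one of transformers' `ModelOutput` containers, which behave like ordered mappings; the new branch therefore compares two such outputs key by key and otherwise falls back to comparing `last_hidden_state` against a plain tensor. A short illustration of that mapping behavior:

    import torch
    from transformers.modeling_outputs import BaseModelOutputWithPooling

    out = BaseModelOutputWithPooling(
        last_hidden_state=torch.zeros(1, 4, 8),
        pooler_output=torch.zeros(1, 8),
    )
    assert list(out) == ["last_hidden_state", "pooler_output"]  # iteration yields keys
    assert torch.equal(out["last_hidden_state"], out.last_hidden_state)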
@@ -1043,24 +1056,28 @@
  "SlidingWindowCache",
  "HybridCache",
  ):
+ from .helpers.cache_helper import CacheKeyValue
+
  self.assertEqual(type(expected), type(value), msg=msg)
- atts = ["key_cache", "value_cache"]
- self.assertEqualAny(
- {k: expected.__dict__.get(k, None) for k in atts},
- {k: value.__dict__.get(k, None) for k in atts},
- atol=atol,
- rtol=rtol,
- )
+ self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
  elif expected.__class__.__name__ == "StaticCache":
+ from .helpers.cache_helper import CacheKeyValue
+
  self.assertEqual(type(expected), type(value), msg=msg)
  self.assertEqual(expected.max_cache_len, value.max_cache_len)
- atts = ["key_cache", "value_cache"]
- self.assertEqualAny(
- {k: expected.__dict__.get(k, None) for k in atts},
- {k: value.__dict__.get(k, None) for k in atts},
- atol=atol,
- rtol=rtol,
- )
+ self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
+ elif expected.__class__.__name__ == "CacheKeyValue":
+ self.assertEqual(type(expected), type(value), msg=msg)
+ if expected.cls_layers is None:
+ self.assertEqual(expected.cls_layers, value.cls_layers)
+ else:
+ self.assertEqualAny(
+ [cls.__name__ for cls in expected.cls_layers],
+ [cls.__name__ for cls in value.cls_layers],
+ msg=msg,
+ )
+ self.assertEqualAny(expected.key_cache, value.key_cache, msg=msg)
+ self.assertEqualAny(expected.value_cache, value.value_cache, msg=msg)
  elif expected.__class__.__name__ == "EncoderDecoderCache":
  self.assertEqual(type(expected), type(value), msg=msg)
  atts = ["self_attention_cache", "cross_attention_cache"]
@@ -105,6 +105,7 @@ def get_parsed_args(
  default=tries,
  )
  for k, v in kwargs.items():
+ assert isinstance(v, tuple) # type
  parser.add_argument(
  f"--{k}",
  help=f"{v[1]}, default is {v[0]}",
@@ -11,7 +11,6 @@ from argparse import Namespace
  from datetime import datetime
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

-
  _DEFAULT_STRING_LIMIT = 2000


@@ -90,7 +90,7 @@ def flatten_unflatten_for_dynamic_shapes(
  the context gives the dictionary keys but it is not expressed
  in the dynamic shapes, these specifications seems to be different
  for the strict and non strict mode. It also preserves tuple.
- :param change_function: to modifies the tensor in the structure itself,
+ :param change_function: to modify the tensor in the structure itself,
  like replace them by a shape
  :return: the serialized object
  """
@@ -110,7 +110,7 @@
  start = end
  if use_dict:
  if spec.type is dict:
- # This a dictionary.
+ # This is a dictionary.
  return dict(zip(spec.context, subtrees))
  if spec.type is tuple:
  return tuple(subtrees)
@@ -348,6 +348,7 @@ else:
  def make_static_cache(
  key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
  max_cache_len: Optional[int] = None,
+ cls_layers: Optional[Union[str, List[type]]] = None,
  ) -> transformers.cache_utils.DynamicCache:
  """
  Creates an instance of :class:`transformers.cache_utils.StaticCache`.
@@ -379,6 +380,9 @@
  )
  print(string_type(past_key_values, with_shape=True))
  """
+ assert not cls_layers or set(cls_layers) == {
+ transformers.cache_utils.StaticLayer
+ }, f"Not implemented when cls_layers={cls_layers!r}"
  key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

  class _config:
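The new `cls_layers` parameter mirrors the same transformers evolution: caches are now assembled from per-layer classes, and the assert makes the current limitation (only `transformers.cache_utils.StaticLayer`) explicit instead of silently ignoring other layer types. The calling convention presumably stays as before, along these lines (assuming the helper is imported from `onnx_diagnostic.helpers.cache_helper`):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    # three layers of (key, value) pairs, shape (batch, heads, seq, dim)
    pairs = [(torch.randn(1, 2, 6, 4), torch.randn(1, 2, 6, 4)) for _ in range(3)]
    cache = make_static_cache(pairs, max_cache_len=6)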
@@ -583,13 +587,9 @@ if hasattr(transformers.cache_utils, "SlidingWindowCache"):
  )
  return finalize_cache(cache)

- def get_make_hybrid_cache():
- return make_sliding_window_cache
-
  else:
  make_sliding_window_cache = None # type: ignore[assignment]

-
  if hasattr(transformers.cache_utils, "HybridCache"):

  def make_hybrid_cache(
@@ -1,5 +1,5 @@
  import os
- from typing import Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple
  import onnx
  import onnx.helper as oh
  import torch
@@ -46,10 +46,10 @@ class LayerNormalizationOrt(OpRunKernel):
  f"This kernel implementation only work when only one output "
  f"is required but {node.output} were."
  )
- self._cache: Dict[Tuple[int, int], onnx.ModelProto] = {}
+ self._cache: Dict[Tuple[int, int], Any] = {}
  self.is_cpu = torch.device("cpu") == self.device

- def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
+ def _make_model(self, itype: int, rank: int, has_bias: bool) -> Any:
  shape = [*["d{i}" for i in range(rank - 1)], "last"]
  layer_model = oh.make_model(
  oh.make_graph(
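`LayerNormalizationOrt` builds and caches a tiny one-node ONNX model per dtype/rank combination; the annotations are relaxed from `onnx.ModelProto` to `Any`, presumably because the cache ends up holding onnxruntime sessions rather than protos. For reference, a minimal standalone model of that shape (not the package's exact graph) looks like:

    import onnx
    import onnx.helper as oh

    itype = onnx.TensorProto.FLOAT
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("LayerNormalization", ["X", "scale"], ["Y"], axis=-1)],
            "ln",
            [
                oh.make_tensor_value_info("X", itype, ["d0", "last"]),
                oh.make_tensor_value_info("scale", itype, ["last"]),
            ],
            [oh.make_tensor_value_info("Y", itype, ["d0", "last"])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    onnx.checker.check_model(model)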
@@ -88,6 +88,7 @@ class LayerNormalizationOrt(OpRunKernel):
  providers=[provider],
  )

+ # pyrefly: ignore[bad-override]
  def run(self, x, scale, bias=None):
  itype = torch_dtype_to_onnx_dtype(x.dtype)
  rank = len(x.shape)
@@ -124,7 +125,7 @@ class MatMulOrt(OpRunKernel):
  self._cache: Dict[Tuple[int, int, int], onnx.ModelProto] = {}
  self.is_cpu = torch.device("cpu") == self.device

- def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
+ def _make_model(self, itype: int, ranka: int, rankb: int) -> Any:
  shapea = ["a{i}" for i in range(ranka)]
  shapeb = ["b{i}" for i in range(rankb)]
  shapec = ["c{i}" for i in range(max(ranka, rankb))]
@@ -149,6 +150,7 @@
  providers=[provider],
  )

+ # pyrefly: ignore[bad-override]
  def run(self, a, b):
  itype = torch_dtype_to_onnx_dtype(a.dtype)
  ranka, rankb = len(a.shape), len(b.shape)
@@ -159,5 +161,6 @@
  if self.verbose:
  print(f"[MatMulOrt] running on {self._provider!r}")
  feeds = dict(A=a.tensor, B=b.tensor)
+ # pyrefly: ignore[missing-attribute]
  got = sess.run(None, feeds)[0]
  return OpRunTensor(got)
@@ -36,7 +36,7 @@ class GraphRendering:
  :return: computation order
  """
  assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
- f"This algorithme is not yet implemented if the sequence contains "
+ f"This algorithm is not yet implemented if the sequence contains "
  f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
  )
  number = {e: start - 1 for e in (existing or [])} # noqa: C420
@@ -131,14 +131,14 @@
  @property
  def nodes(self) -> List[onnx.NodeProto]:
  "Returns the list of nodes"
- return (
+ return list(
  self.proto.graph.node
  if isinstance(self.proto, onnx.ModelProto)
  else self.proto.node
  )

  @property
- def start_names(self) -> List[onnx.NodeProto]:
+ def start_names(self) -> List[str]:
  "Returns the list of known names, inputs and initializer"
  graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
  input_names = (
@@ -151,7 +151,7 @@
  if isinstance(graph, onnx.FunctionProto)
  else [
  *[i.name for i in graph.initializer],
- *[i.name for i in graph.sparse_initializer],
+ *[i.values.name for i in graph.sparse_initializer],
  ]
  )
  return [*input_names, *init_names]
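The `start_names` fix follows from the ONNX protobuf layout: `SparseTensorProto` has no `name` field of its own, so `i.name` on a sparse initializer fails; the identifier lives on the nested dense `values` tensor.

    import onnx

    sp = onnx.SparseTensorProto()
    sp.values.name = "sparse_weight"  # the identifier is carried by `values`
    print(sp.values.name)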
@@ -159,7 +159,7 @@
  @property
  def input_names(self) -> List[str]:
  "Returns the list of input names."
- return (
+ return list(
  self.proto.input
  if isinstance(self.proto, onnx.FunctionProto)
  else [
@@ -173,7 +173,7 @@
  @property
  def output_names(self) -> List[str]:
  "Returns the list of output names."
- return (
+ return list(
  self.proto.output
  if isinstance(self.proto, onnx.FunctionProto)
  else [
@@ -29,10 +29,10 @@ class CubeViewDef:
  :param order: to reorder key in columns index
  :param key_agg: aggregate according to these columns before
  creating the view
- :param agg_args: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`,
+ :param agg_args: see :meth:`pandas.api.typing.DataFrameGroupBy.agg`,
  it can be also a callable to return a different aggregation
  method depending on the column name
- :param agg_kwargs: see :meth:`pandas.core.groupby.DataFrameGroupBy.agg`
+ :param agg_kwargs: see :meth:`pandas.api.typing.DataFrameGroupBy.agg`
  :param agg_multi: aggregation over multiple columns
  :param ignore_columns: ignore the following columns if known to overload the view
  :param keep_columns_in_index: keeps the columns even if there is only one unique value
@@ -98,7 +98,7 @@ class CubeViewDef:
  agg_args: Union[Sequence[Any], Callable[[str], Any]] = ("sum",),
  agg_kwargs: Optional[Dict[str, Any]] = None,
  agg_multi: Optional[
- Dict[str, Callable[[pandas.core.groupby.DataFrameGroupBy], pandas.Series]]
+ Dict[str, Callable[[pandas.api.typing.DataFrameGroupBy], pandas.Series]]
  ] = None,
  ignore_columns: Optional[Sequence[str]] = None,
  keep_columns_in_index: Optional[Sequence[str]] = None,
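`pandas.core` is private API; since pandas 2.1 the groupby classes are re-exported for annotation purposes under `pandas.api.typing`, which is what the docstrings and the `agg_multi` annotation now reference:

    import pandas
    from pandas.api.typing import DataFrameGroupBy

    df = pandas.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})
    gr: DataFrameGroupBy = df.groupby("k")
    print(gr.agg("sum"))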
@@ -365,6 +365,7 @@ class CubePlot:
  # This is very slow
  # ddd.plot(ax=axs[row, ii],linewidth=3)
  for jj in range(ddd.shape[1]):
+ # pyrefly: ignore[bad-index]
  axs[row, ii].plot(x, ddd.iloc[:, jj], lw=3, label=ddd.columns[jj])
  axs[row, ii].set_title(f"{c}{title_suffix}")
  rotate_align(axs[row, ii])
@@ -480,7 +481,9 @@ class CubeLogs:
  elif isinstance(self._data, list) and all(isinstance(r, dict) for r in self._data):
  if verbose:
  print(f"[CubeLogs.load] load from list of dicts, n={len(self._data)}")
- self.data = pandas.DataFrame(self.post_load_process_piece(self._data, unique=True))
+ self.data = pandas.DataFrame(
+ self.post_load_process_piece(pandas.DataFrame(self._data), unique=True)
+ )
  if verbose:
  print(f"[CubeLogs.load] after postprocessing shape={self.data.shape}")
  elif isinstance(self._data, list) and all(
@@ -614,7 +617,7 @@

  def _process_formula(
  self, formula: Union[str, Callable[[pandas.DataFrame], pandas.Series]]
- ) -> Callable[[pandas.DataFrame], pandas.Series]:
+ ) -> Callable[[pandas.DataFrame], Optional[pandas.Series]]:
  assert callable(formula), f"formula={formula!r} is not supported."
  return formula

@@ -625,9 +628,11 @@
  return self.data.shape

  @property
- def columns(self) -> Sequence[str]:
+ def columns(self) -> Sequence[Any]:
  "Returns the columns."
  assert hasattr(self, "data"), "Method load was not called"
+ assert isinstance(self.data, pandas.DataFrame) # type checking
+ # pyrefly: ignore[bad-return]
  return self.data.columns

  def _preprocess(self):
@@ -647,7 +652,7 @@
  )
  assert gr.shape[0] > 0, (
  f"Something went wrong after the groupby.\n"
- f"{cp[[*self.keys, self.time, '__index__']].head().T}"
+ f"{cp[[*self.keys_no_time, self.time, '__index__']].head().T}"
  )
  filtered = pandas.merge(cp, gr, on=["__index__", *self.keys_time])
  assert filtered.shape[0] <= self.data.shape[0], (
@@ -797,6 +802,7 @@
  if view_def.agg_multi:
  append = []
  for k, f in view_def.agg_multi.items():
+ # pyrefly: ignore[no-matching-overload]
  cv = grouped_data.apply(f, include_groups=False)
  append.append(cv.to_frame(k))
  data = pandas.concat([data, *append], axis=1)
@@ -1020,8 +1026,10 @@

  keys = set(self.keys_no_time) - {columns_to_fix}
  select = data[self.keys_no_time]
+ # pyrefly: ignore[no-matching-overload]
  select_agg = select.groupby(list(keys), as_index=True).apply(
- lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))), include_groups=False
+ lambda x: "-".join(sorted(set(x[columns_to_fix].dropna()))),
+ include_groups=False,
  )
  select_agg = select_agg.to_frame(name=columns_to_fix)
  res = pandas.merge(
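Splitting the lambda and `include_groups=False` onto separate lines is cosmetic, but the keyword itself follows a pandas 2.2 deprecation: `DataFrameGroupBy.apply` used to pass the grouping columns along inside each group and now warns unless they are excluded. The call above, reduced to a sketch:

    import pandas

    df = pandas.DataFrame({"k": ["a", "a", "b"], "v": ["x", None, "y"]})
    agg = df.groupby("k", as_index=True).apply(
        lambda g: "-".join(sorted(set(g["v"].dropna()))),
        include_groups=False,  # grouping column "k" is not part of g
    )
    print(agg.to_frame(name="v"))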
@@ -1137,6 +1145,7 @@
  if len(nonan) > 0:
  obs.update(dict(count=len(nonan)))
  if is_numeric_dtype(nonan) and not pandas.api.types.is_object_dtype(nonan):
+ # pyrefly: ignore[no-matching-overload]
  obs.update(
  dict(
  min=nonan.min(),
@@ -1208,12 +1217,15 @@
  df.to_excel(writer, sheet_name=main, freeze_panes=(1, 1))

  time_mask_view: Dict[str, pandas.DataFrame] = {}
+ df = None
  for name, view in views.items():
  if view is None:
  continue
  df, tview = self.view(view, return_view_def=True, verbose=max(verbose - 1, 0))
  if cube_time is not None:
  cube_mask = cube_time.view(view)
+ assert isinstance(cube_mask, pandas.DataFrame) # type checking
+ assert isinstance(df, pandas.DataFrame) # type checking
  aligned = align_dataframe_with(cube_mask, df)
  if aligned is not None:
  assert aligned.shape == df.shape, (
@@ -1228,6 +1240,7 @@
  )
  if tview is None:
  continue
+ assert isinstance(df, pandas.DataFrame) # type checking
  memory = df.memory_usage(deep=True).sum()
  if verbose:
  print(
@@ -1269,7 +1282,9 @@
  sheet_name=name,
  freeze_panes=(df.columns.nlevels + 1, df.index.nlevels),
  )
+ # pyrefly: ignore[missing-attribute]
  f_highlights[name] = tview.f_highlight
+ # pyrefly: ignore[missing-attribute]
  if tview.plots:
  plots.append(
  CubePlot(
@@ -1282,6 +1297,7 @@
  if self.time in df.columns.names
  else CubePlot(df, kind="barh", orientation="row", split=True)
  )
+ assert isinstance(df, pandas.DataFrame) # type checking
  if raw:
  assert main not in views, f"{main!r} is duplicated in views {sorted(views)}"
  # Too long.
@@ -1439,7 +1455,7 @@
  len(configs) >= 2
  ), f"A side by side needs at least two configs but configs={configs}"
  set_keys_time = set(self.keys_time)
- columns_index = None
+ columns_index: Optional[List[str]] = None
  data_list = []
  for name_conf, conf in configs.items():
  if columns_index is None:
@@ -1478,9 +1494,11 @@

  # add metrics
  index_column_name = list(view_res.columns.names).index(column_name)
+ # pyrefly: ignore[missing-attribute]
  index_metrics = list(view_res.columns.names).index("METRICS")

  def _mkc(m, s):
+ # pyrefly: ignore[missing-attribute]
  c = ["" for c in view_res.columns.names]
  c[index_column_name] = s
  c[index_metrics] = m
@@ -1515,7 +1533,9 @@
  ci["CONF"] = iname
  cj["CONF"] = jname

+ # pyrefly: ignore[bad-index]
  ci_name = tuple(ci[n] for n in view_res.columns.names)
+ # pyrefly: ignore[bad-index]
  cj_name = tuple(cj[n] for n in view_res.columns.names)
  assert ci_name in view_res.columns or cj_name in view_res.columns, (
  f"Unable to find column {ci_name} or {cj_name} "
@@ -1562,6 +1582,7 @@
  }
  flat = view_res.groupby(self.time).agg(aggs)
  flat = flat.stack("METRICS", future_stack=True)
+ # pyrefly: ignore[bad-return, missing-attribute]
  return res, flat, view_res.T.sort_index().T


@@ -1679,7 +1700,7 @@ class CubeLogsPerformance(CubeLogs):

  def _process_formula(
  self, formula: Union[str, Callable[[pandas.DataFrame], pandas.Series]]
- ) -> Callable[[pandas.DataFrame], pandas.Series]:
+ ) -> Callable[[pandas.DataFrame], Optional[pandas.Series]]:
  """
  Processes a formula, converting it into a function.

@@ -1726,6 +1747,7 @@
  f"{pprint.pformat(sorted(columns))}"
  )
  # return lambda df: df["time_latency_eager"] / df["time_latency"]
+ # pyrefly: ignore[no-matching-overload]
  return lambda df: pandas.cut(
  df["speedup"], bins=BUCKET_SCALES, right=False, duplicates="raise"
  )
@@ -1733,9 +1755,9 @@
  if formula == "ERR1":
  columns = set(self._filter_column(["^ERR_.*"], self.data.columns))
  if not columns:
- return lambda df: np.nan
+ return lambda df: None

- def first_err(df: pandas.DataFrame) -> pandas.Series:
+ def first_err(df: pandas.DataFrame) -> Optional[pandas.Series]:
  ordered = [
  c
  for c in [
@@ -1752,7 +1774,7 @@
  ]
  if c in df.columns
  ]
- res = None
+ res: Optional[pandas.Series] = None
  for c in ordered:
  if res is None:
  res = df[c].fillna("")
@@ -1949,6 +1971,7 @@
  f"{pprint.pformat(sorted(self.data.columns))}"
  )

+ # pyrefly: ignore[bad-override]
  def view(
  self,
  view_def: Optional[Union[str, CubeViewDef]],
@@ -2265,7 +2288,7 @@
  if unique:
  return df
  cols = self._filter_column(self._keys, df)
- res = None
+ res: Optional[pandas.DataFrame] = None
  for c in cols:
  if df[c].isna().any():
  # Missing values for keys are not supposed to happen.
@@ -103,6 +103,7 @@ def _process_memory_spy(conn):
  process = psutil.Process(pid)

  if cuda:
+ # pyrefly: ignore[missing-import]
  from pynvml import (
  nvmlDeviceGetCount,
  nvmlDeviceGetHandleByIndex,
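The `# pyrefly: ignore[missing-import]` covers an optional dependency: `pynvml` (distributed as the `nvidia-ml-py` package) is only imported when CUDA memory is being spied on. The NVML calls in this module follow the usual init/query/shutdown pattern, roughly as below (`nvmlDeviceGetMemoryInfo` is an assumption; only the imports above and `nvmlShutdown` appear in this diff):

    from pynvml import (  # optional dependency, package nvidia-ml-py
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
        nvmlInit,
        nvmlShutdown,
    )

    nvmlInit()
    handles = [nvmlDeviceGetHandleByIndex(i) for i in range(nvmlDeviceGetCount())]
    used = [nvmlDeviceGetMemoryInfo(h).used for h in handles]
    nvmlShutdown()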
@@ -131,6 +132,7 @@
  mem = process.memory_info().rss
  cpu.update(mem)
  if cuda:
+ # pyrefly: ignore[unbound-name]
  for r, g in zip(gpu_used(), gpus):
  g.update(r)
  if conn.poll(timeout=timeout):
@@ -142,6 +144,7 @@
  end = process.memory_info().rss
  cpu.update(end)
  if cuda:
+ # pyrefly: ignore[unbound-name]
  for r, g in zip(gpu_used(), gpus):
  g.update(r)

@@ -151,6 +154,7 @@
  for g in gpus:
  g.send(conn)
  if cuda:
+ # pyrefly: ignore[unbound-name]
  nvmlShutdown()
  conn.close()

@@ -217,7 +221,7 @@
  Starts the memory spy. The function starts another
  process spying on the one sent as an argument.

- :param pid: process id to spy or the the current one.
+ :param pid: process id to spy or the current one.
  :param delay: delay between two measures.
  :param cuda: True or False to get memory for cuda devices

@@ -8,11 +8,6 @@ import torch
  from .onnx_helper import dtype_to_tensor_dtype, tensor_dtype_to_np_dtype, from_array_extended
  from . import string_type

- STORAGE_TYPE = {
- TensorProto.FLOAT16: np.int16,
- TensorProto.BFLOAT16: np.int16,
- }
-

  def proto_from_array(
  arr: torch.Tensor,
@@ -67,13 +62,13 @@
  byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
  tensor.raw_data = bytes(byte_data)
  if sys.byteorder == "big":
- np_dtype = tensor_dtype_to_np_dtype(STORAGE_TYPE[tensor.data_type])
- np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
+ np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
+ np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
  else:
  tensor.raw_data = np_arr.tobytes()
  if sys.byteorder == "big":
  np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
- np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
+ np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)

  return tensor

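This hunk fixes two genuine big-endian bugs: NumPy has no module-level `np.byteswap` (byte order is swapped with the `ndarray.byteswap` method), and the removed `STORAGE_TYPE` lookup raised a `KeyError` for any dtype other than FLOAT16/BFLOAT16. The corrected idiom:

    import numpy as np

    raw = np.arange(4, dtype=np.float32).tobytes()
    arr = np.frombuffer(raw, dtype=np.float32).copy()  # frombuffer arrays are read-only
    arr.byteswap(inplace=True)  # ndarray method, not np.byteswap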
@@ -133,6 +128,7 @@ class MiniOnnxBuilder:
  }
  shape = tuple(map(int, tensor.shape))
  self.nodes.append(
+ # pyrefly: ignore[bad-argument-type]
  oh.make_node(op_type, [], [name], dtype=dtype, shape=shape, **kwargs)
  )
  self.outputs.append(oh.make_tensor_value_info(name, dtype, shape))
@@ -632,6 +628,7 @@
  raise AssertionError(f"Unexpected value for engine={engine!r}")

  got = sess.run(None, {})
+ assert isinstance(got, list) # type checking
  if len(names) == 1:
  name = names[0]
  output = got[0]
@@ -639,12 +636,10 @@
  return None
  if name == "array":
  return output
- if name == "bool":
- return bool(output[0])
- if name == "int":
- return int(output[0])
- if name == "float":
- return float(output[0])
+ if name in {"bool", "int", "float"}:
+ cvt = {"bool": bool, "int": int, "float": float}[name]
+ # pyrefly: ignore[bad-index]
+ return cvt(output[0])
  if name == "tensor":
  return torch.from_numpy(output).to(device)
  assert name.startswith(