onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +22 -5
  3. onnx_diagnostic/ext_test_case.py +31 -0
  4. onnx_diagnostic/helpers/cache_helper.py +23 -12
  5. onnx_diagnostic/helpers/config_helper.py +16 -1
  6. onnx_diagnostic/helpers/log_helper.py +308 -83
  7. onnx_diagnostic/helpers/rt_helper.py +11 -1
  8. onnx_diagnostic/helpers/torch_helper.py +7 -3
  9. onnx_diagnostic/tasks/__init__.py +2 -0
  10. onnx_diagnostic/tasks/text_generation.py +17 -8
  11. onnx_diagnostic/tasks/text_to_image.py +91 -0
  12. onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
  15. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
  16. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  17. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  18. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  19. onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
  20. onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
  23. onnx_diagnostic/torch_models/validate.py +36 -12
  24. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
  25. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
  26. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.7.1"
6
+ __version__ = "0.7.3"
7
7
  __author__ = "Xavier Dupré"
@@ -718,13 +718,13 @@ def get_parser_agg() -> ArgumentParser:
718
718
  "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
719
719
  "n_node_constant,n_node_shape,n_node_expand,"
720
720
  "n_node_function,n_node_initializer,n_node_scatter,"
721
- "time_export_unbiased",
721
+ "time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
722
722
  help="Columns to compute after the aggregation was done.",
723
723
  )
724
724
  parser.add_argument(
725
725
  "--views",
726
726
  default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
727
- "bucket-speedup,raw-short,counts,peak-gpu",
727
+ "bucket-speedup,raw-short,counts,peak-gpu,onnx",
728
728
  help="Views to add to the output files.",
729
729
  )
730
730
  parser.add_argument(
@@ -733,11 +733,28 @@ def get_parser_agg() -> ArgumentParser:
733
733
  help="Views to dump as csv files.",
734
734
  )
735
735
  parser.add_argument("-v", "--verbose", type=int, default=0, help="verbosity")
736
+ parser.add_argument(
737
+ "--filter-in",
738
+ default="",
739
+ help="adds a filter to filter in data, syntax is\n"
740
+ '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
741
+ )
742
+ parser.add_argument(
743
+ "--filter-out",
744
+ default="",
745
+ help="adds a filter to filter out data, syntax is\n"
746
+ '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
747
+ )
736
748
  return parser
737
749
 
738
750
 
739
751
  def _cmd_agg(argv: List[Any]):
740
- from .helpers.log_helper import CubeLogsPerformance, open_dataframe, enumerate_csv_files
752
+ from .helpers.log_helper import (
753
+ CubeLogsPerformance,
754
+ open_dataframe,
755
+ enumerate_csv_files,
756
+ filter_data,
757
+ )
741
758
 
742
759
  parser = get_parser_agg()
743
760
  args = parser.parse_args(argv[1:])
@@ -748,7 +765,7 @@ def _cmd_agg(argv: List[Any]):
748
765
  args.inputs, verbose=args.verbose, filtering=lambda name: bool(reg.search(name))
749
766
  )
750
767
  )
751
- assert csv, f"No csv files in {args.inputs}, csv={csv}"
768
+ assert csv, f"No csv files in {args.inputs}, args.filter={args.filter!r}, csv={csv}"
752
769
  if args.verbose:
753
770
  from tqdm import tqdm
754
771
 
@@ -761,7 +778,7 @@ def _cmd_agg(argv: List[Any]):
761
778
  assert (
762
779
  args.time in df.columns
763
780
  ), f"Missing time column {args.time!r} in {c!r}\n{df.head()}\n{sorted(df.columns)}"
764
- dfs.append(df)
781
+ dfs.append(filter_data(df, filter_in=args.filter_in, filter_out=args.filter_out))
765
782
 
766
783
  drop_keys = set(args.drop_keys.split(","))
767
784
  cube = CubeLogsPerformance(
@@ -756,6 +756,18 @@ class ExtTestCase(unittest.TestCase):
756
756
  "Adds a todo printed when all test are run."
757
757
  cls._todos.append((f, msg))
758
758
 
759
+ @classmethod
760
+ def ort(cls):
761
+ import onnxruntime
762
+
763
+ return onnxruntime
764
+
765
+ @classmethod
766
+ def to_onnx(self, *args, **kwargs):
767
+ from experimental_experiment.torch_interpreter import to_onnx
768
+
769
+ return to_onnx(*args, **kwargs)
770
+
759
771
  def print_model(self, model: "ModelProto"): # noqa: F821
760
772
  "Prints a ModelProto"
761
773
  from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
@@ -917,6 +929,15 @@ class ExtTestCase(unittest.TestCase):
917
929
  ]
918
930
  raise AssertionError("\n".join(rows)) # noqa: B904
919
931
 
932
+ def assertEqualDataFrame(self, d1, d2, **kwargs):
933
+ """
934
+ Checks that two dataframes are equal.
935
+ Calls :func:`pandas.testing.assert_frame_equal`.
936
+ """
937
+ from pandas.testing import assert_frame_equal
938
+
939
+ assert_frame_equal(d1, d2, **kwargs)
940
+
920
941
  def assertEqualTrue(self, value: Any, msg: str = ""):
921
942
  if value is True:
922
943
  return
@@ -967,6 +988,16 @@ class ExtTestCase(unittest.TestCase):
967
988
  atol=atol,
968
989
  rtol=rtol,
969
990
  )
991
+ elif expected.__class__.__name__ == "StaticCache":
992
+ self.assertEqual(type(expected), type(value), msg=msg)
993
+ self.assertEqual(expected.max_cache_len, value.max_cache_len)
994
+ atts = ["key_cache", "value_cache"]
995
+ self.assertEqualAny(
996
+ {k: expected.__dict__.get(k, None) for k in atts},
997
+ {k: value.__dict__.get(k, None) for k in atts},
998
+ atol=atol,
999
+ rtol=rtol,
1000
+ )
970
1001
  elif expected.__class__.__name__ == "EncoderDecoderCache":
971
1002
  self.assertEqual(type(expected), type(value), msg=msg)
972
1003
  atts = ["self_attention_cache", "cross_attention_cache"]
@@ -154,10 +154,12 @@ else:
154
154
 
155
155
  def make_static_cache(
156
156
  key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
157
+ max_cache_len: Optional[int] = None,
157
158
  ) -> transformers.cache_utils.DynamicCache:
158
159
  """
159
160
  Creates an instance of :class:`transformers.cache_utils.StaticCache`.
160
161
  :param key_value_pairs: list of pairs of (key, values)
162
+ :param max_cache_len: max_cache_length or something inferred from the vector
161
163
  :return: :class:`transformers.cache_utils.StaticCache`
162
164
 
163
165
  Example:
@@ -179,7 +181,8 @@ def make_static_cache(
179
181
  torch.randn(bsize, nheads, slen, dim),
180
182
  )
181
183
  for i in range(n_layers)
182
- ]
184
+ ],
185
+ max_cache_len=10,
183
186
  )
184
187
  print(string_type(past_key_values, with_shape=True))
185
188
  """
@@ -190,24 +193,32 @@ def make_static_cache(
190
193
  self.num_attention_heads = key_value_pairs[0][0].shape[1]
191
194
  self.num_hidden_layers = len(key_value_pairs)
192
195
 
196
+ assert max_cache_len is not None, (
197
+ f"max_cache_len={max_cache_len} cannot be setup "
198
+ f"automatically yet from shape {key_value_pairs[0][0].shape}"
199
+ )
200
+ torch._check(
201
+ max_cache_len >= key_value_pairs[0][0].shape[2],
202
+ (
203
+ f"max_cache_len={max_cache_len} cannot be smaller "
204
+ f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
205
+ f"{key_value_pairs[0][0].shape}"
206
+ ),
207
+ )
193
208
  cache = transformers.cache_utils.StaticCache(
194
209
  _config(),
195
210
  max_batch_size=key_value_pairs[0][0].shape[0],
196
211
  device=key_value_pairs[0][0].device,
197
212
  dtype=key_value_pairs[0][0].dtype,
198
- max_cache_len=key_value_pairs[0][0].shape[2],
213
+ max_cache_len=max_cache_len,
199
214
  )
200
215
  for i in range(len(key_value_pairs)):
201
- assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
202
- f"Shape mismatch, expected {cache.key_cache[i].shape}, "
203
- f"got {key_value_pairs[i][0].shape}"
204
- )
205
- cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
206
- assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
207
- f"Shape mismatch, expected {cache.value_cache[i].shape}, "
208
- f"got {key_value_pairs[i][1].shape}"
209
- )
210
- cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
216
+ assert (
217
+ key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
218
+ ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
219
+ d = key_value_pairs[i][1].shape[2]
220
+ cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
221
+ cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
211
222
  return cache
212
223
 
213
224
 
@@ -43,7 +43,10 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
43
43
  else:
44
44
  update_config(getattr(config, k), v)
45
45
  continue
46
- setattr(config, k, v)
46
+ if type(config) is dict:
47
+ config[k] = v
48
+ else:
49
+ setattr(config, k, v)
47
50
 
48
51
 
49
52
  def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
@@ -66,6 +69,18 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
66
69
  raise AssertionError(f"Unable to find any of these {atts!r} in {config}")
67
70
 
68
71
 
72
+ def pick(config, name: str, default_value: Any) -> Any:
73
+ """
74
+ Returns the value of a attribute if config has it
75
+ otherwise the default value.
76
+ """
77
+ if not config:
78
+ return default_value
79
+ if type(config) is dict:
80
+ return config.get(name, default_value)
81
+ return getattr(config, name, default_value)
82
+
83
+
69
84
  @functools.cache
70
85
  def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
71
86
  """