onnx-diagnostic 0.7.11__py3-none-any.whl → 0.7.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +5 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +65 -12
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +55 -37
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +15 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +40 -14
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_api.py +4 -10
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -4
  24. onnx_diagnostic/torch_models/validate.py +337 -113
  25. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  26. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
  27. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.7.11"
+__version__ = "0.7.13"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -581,6 +581,7 @@ def _cmd_validate(argv: List[Any]):
     ):
         print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
         return
+    patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
     summary, _data = validate_model(
         model_id=args.mid,
         task=args.task,
@@ -591,8 +592,8 @@ def _cmd_validate(argv: List[Any]):
         use_pretrained=args.trained,
         dtype=args.dtype,
         device=args.device,
-        patch=args.patch,
-        rewrite=args.rewrite,
+        patch=patch_dict,
+        rewrite=args.rewrite and patch_dict.get("patch", True),
         stop_if_static=args.stop_if_static,
         optimization=args.opt,
         exporter=args.export,
@@ -827,6 +828,8 @@ def get_parser_agg() -> ArgumentParser:
         "n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
         "n_model_pass,n_model_faster,"
         "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
+        "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
+        "n_node_layer_normalization,n_node_layer_normalization23,"
        "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
        "n_node_constant,n_node_shape,n_node_expand,"
        "n_node_function,n_node_initializer,n_node_scatter,"
onnx_diagnostic/export/dynamic_shapes.py
@@ -56,6 +56,14 @@ class CoupleInputsDynamicShapes:
         self.kwargs = kwargs
         self.dynamic_shapes = dynamic_shapes
         self.args_names = args_names
+        if not self.kwargs and isinstance(self.dynamic_shapes, dict):
+            # This assumes the dictionary for the dynamic shapes is ordered
+            # the same way the args are. The input names are not known.
+            assert len(self.dynamic_shapes) == len(self.args), (
+                f"Length mismatch, kwargs is empty, len(dynamic_shapes)="
+                f"{len(self.dynamic_shapes)}, len(args)={len(self.args)}"
+            )
+            self.dynamic_shapes = tuple(self.dynamic_shapes.values())

     def __str__(self) -> str:
         return "\n".join(
@@ -232,8 +240,9 @@ class CoupleInputsDynamicShapes:
         """
         if not self.args:
             assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
-                f"Type mismatch, args={string_type(self.args)} and "
-                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
+                f"Type mismatch, args={string_type(self.args)}, "
+                f"kwargs={string_type(self.kwargs)} and dynamic_shapes="
+                f"{string_type(self.dynamic_shapes)} should have the same type."
             )
             res = self._generic_walker_step(
                 processor,
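The new constructor branch turns a name-keyed dynamic-shapes dictionary into a positional tuple when only positional args are given. A minimal sketch, with hypothetical input names and shapes::

    import torch

    args = (torch.randn(2, 3), torch.randn(2, 4))
    dynamic_shapes = {"x": {0: "batch"}, "y": {0: "batch"}}  # ordered like args
    assert len(dynamic_shapes) == len(args)  # mirrors the new assertion
    dynamic_shapes = tuple(dynamic_shapes.values())
    # ({0: "batch"}, {0: "batch"}) is now aligned with args purely by position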
onnx_diagnostic/helpers/helper.py
@@ -397,7 +397,7 @@ def string_type(
             return "AUTO"
         if verbose:
             print(f"[string_type] Y7:{type(obj)}")
-        return str(obj)
+        return str(obj).replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO")

     if isinstance(obj, bool):
         if with_min_max:
@@ -516,8 +516,10 @@ def string_type(
             print(f"[string_type] V2:{type(obj)}")
             return "OV(NOTENSOR)"
         if with_min_max:
+            from .torch_helper import to_numpy
+
             try:
-                t = obj.numpy()
+                t = to_numpy(obj)
             except Exception:
                 # pass unable to convert into numpy (bfloat16, ...)
                 if verbose:
@@ -939,7 +941,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
             return flatten_object(list(x.values()), drop_keys=drop_keys)
         return flatten_object(list(x.items()), drop_keys=drop_keys)

-    if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
         from .cache_helper import CacheKeyValue

         kc = CacheKeyValue(x)
@@ -1233,9 +1235,13 @@ def max_diff(

     if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
         if isinstance(expected, torch.Tensor):
-            expected = expected.detach().cpu().numpy()
+            from .torch_helper import to_numpy
+
+            expected = to_numpy(expected)
         if isinstance(got, torch.Tensor):
-            got = got.detach().cpu().numpy()
+            from .torch_helper import to_numpy
+
+            got = to_numpy(got)
         if verbose >= 6:
             print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")

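These replacements route conversions through to_numpy because Tensor.numpy() alone fails in common diagnostic situations. A short illustration of the standard path (see the to_numpy change in torch_helper.py further down)::

    import torch

    t = torch.randn(2, 2, requires_grad=True)
    # t.numpy() raises RuntimeError here, and also fails for CUDA tensors;
    # detaching and moving to CPU first is what the helper now does.
    a = t.detach().cpu().numpy()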
onnx_diagnostic/helpers/log_helper.py
@@ -285,7 +285,8 @@ class CubePlot:
         nn = df.shape[1] // n_cols
         nn += int(df.shape[1] % n_cols != 0)
         ratio = float(os.environ.get("FIGSIZEH", "1"))
-        fig, axs = plt.subplots(nn, n_cols, figsize=(6 * n_cols, nn * df.shape[0] / 3 * ratio))
+        figsize = (6 * n_cols, nn * (2.5 + df.shape[0] / 15) * ratio)
+        fig, axs = plt.subplots(nn, n_cols, figsize=figsize)
         pos = 0
         imgs = []
         for c in self._make_loop(df.columns, verbose):
@@ -332,10 +333,12 @@ class CubePlot:
         n_cols = len(groups)

         title_suffix = f"\n{title_suffix}" if title_suffix else ""
+        ratio = float(os.environ.get("FIGSIZEH", "1"))
+        figsize = (5 * n_cols, max(len(g) for g in groups) * (2 + df.shape[1] / 2) * ratio)
         fig, axs = plt.subplots(
             df.shape[1],
             n_cols,
-            figsize=(5 * n_cols, max(len(g) for g in groups) * df.shape[1] / 2),
+            figsize=figsize,
             sharex=True,
             sharey="row" if n_cols > 1 else False,
         )
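In the first hunk, the per-figure height now grows much more slowly with the number of rows; the second hunk applies a similar adjustment and starts honoring FIGSIZEH as well. With hypothetical sizes::

    ratio = 1.0  # FIGSIZEH default
    rows = 60    # hypothetical df.shape[0]
    nn = 2       # hypothetical number of subplot rows
    old_height = nn * rows / 3 * ratio           # 40.0, linear in rows
    new_height = nn * (2.5 + rows / 15) * ratio  # 13.0, a base plus a damped term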
@@ -877,7 +880,11 @@ class CubeLogs:
             print(f"[CubeLogs.view] key_columns={key_columns}")
         g = data[[*key_index, *key_columns]].copy()
         g["count"] = 1
-        r = g.groupby([*key_index, *key_columns], dropna=False).sum()
+        r = (
+            g.copy()
+            if not key_index and not key_columns
+            else g.groupby([*key_index, *key_columns], dropna=False).sum()
+        )
         not_unique = r[r["count"] > 1]
         assert not_unique.shape[0] == 0, (
             f"view_def.name={view_def.name!r}, "
@@ -1505,6 +1512,11 @@ class CubeLogsPerformance(CubeLogs):
             "n_model_faster3x",
             "n_model_faster4x",
             "n_node_attention",
+            "n_node_attention23",
+            "n_node_rotary_embedding",
+            "n_node_rotary_embedding23",
+            "n_node_layer_normalization",
+            "n_node_layer_normalization23",
             "n_node_control_flow",
             "n_node_scatter",
             "n_node_function",
@@ -1568,7 +1580,9 @@ class CubeLogsPerformance(CubeLogs):

         def gdf(df, cname, default_value=np.nan):
             if cname in df.columns:
-                return df[cname]
+                if np.isnan(default_value):
+                    return df[cname]
+                return df[cname].fillna(default_value)
             return pandas.Series(default_value, index=df.index)

         def ghas_value(df, cname):
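A self-contained sketch of the patched helper: a missing column still becomes a constant series, but an explicit default now also fills NaN holes in an existing column::

    import numpy as np
    import pandas as pd

    def gdf(df, cname, default_value=np.nan):
        if cname in df.columns:
            if np.isnan(default_value):
                return df[cname]
            return df[cname].fillna(default_value)
        return pd.Series(default_value, index=df.index)

    df = pd.DataFrame({"op_onnx__If": [2.0, np.nan]})
    print(gdf(df, "op_onnx__If", 0).tolist())    # [2.0, 0.0]
    print(gdf(df, "op_onnx__Scan", 0).tolist())  # [0, 0]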
@@ -1676,15 +1690,54 @@ class CubeLogsPerformance(CubeLogs):
                 "time_latency",
                 gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 3.98,
             ),
+            n_node_attention23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__Attention")
+            ),
+            n_node_rotary_embedding23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__RotaryEmbedding")
+            ),
+            n_node_layer_normalization23=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__LayerNormalization", 0)
+                + gdf(df, "op_onnx__RMSNormalization", 0)
+                + gdf(df, "op_onnx__BatchNormlization", 0)
+                + gdf(df, "op_onnx__InstanceNormlization", 0)
+                + gdf(df, "op_onnx__GroupNormalization", 0),
+            ),
             n_node_attention=lambda df: gpreserve(
                 df,
-                "op_onnx_com.microsoft_Attention",
-                gdf(df, "op_onnx_com.microsoft_Attention")
-                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention"),
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_Attention", 0)
+                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_GroupQueryAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PagedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_LongformerAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedSelfAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_SparseAttention", 0),
+            ),
+            n_node_layer_normalization=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_EmbedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_LayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipSimplifiedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SimplifiedLayerNormalization", 0),
+            ),
+            n_node_rotary_embedding=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_GemmaRotaryEmbedding", 0)
+                + gdf(df, "op_onnx_com.microsoft_RotaryEmbedding", 0),
             ),
             n_node_control_flow=lambda df: gpreserve(
                 df,
-                "op_onnx__If",
+                "time_latency_eager",
                 (
                     gdf(df, "op_onnx__If", 0)
                     + gdf(df, "op_onnx__Scan", 0)
@@ -1693,7 +1746,7 @@ class CubeLogsPerformance(CubeLogs):
                 ),
             n_node_scatter=lambda df: gpreserve(
                 df,
-                "op_onnx__ScatterND",
+                "time_latency_eager",
                 gdf(df, "op_onnx__ScatterND", 0) + gdf(df, "op_onnx__ScatterElements", 0),
             ),
             n_node_function=lambda df: gpreserve(
@@ -1706,13 +1759,13 @@ class CubeLogsPerformance(CubeLogs):
                 df, "onnx_n_initializer", gdf(df, "onnx_n_initializer")
             ),
             n_node_constant=lambda df: gpreserve(
-                df, "op_onnx__Constant", gdf(df, "op_onnx__Constant")
+                df, "time_latency_eager", gdf(df, "op_onnx__Constant")
             ),
             n_node_shape=lambda df: gpreserve(
-                df, "op_onnx__Shape", gdf(df, "op_onnx__Shape")
+                df, "time_latency_eager", gdf(df, "op_onnx__Shape")
             ),
             n_node_expand=lambda df: gpreserve(
-                df, "op_onnx__Expand", gdf(df, "op_onnx__Expand")
+                df, "time_latency_eager", gdf(df, "op_onnx__Expand")
             ),
         )
         assert (
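The systematic gdf(..., 0) defaults matter because pandas propagates NaN through addition, so a model missing a single op column would otherwise lose its whole count::

    import numpy as np
    import pandas as pd

    a = pd.Series([1.0, np.nan])
    b = pd.Series([np.nan, 2.0])
    print((a + b).tolist())                      # [nan, nan]
    print((a.fillna(0) + b.fillna(0)).tolist())  # [1.0, 2.0]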
onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -381,6 +381,23 @@ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
             else:
                 for p, o in _flatten_iterator(getattr(obj, att), sep):
                     yield f"DynamicCache_{att}{sep}{p}", o
+    elif obj.__class__.__name__ == "StaticCache":
+        # transformers
+        import transformers
+        from .cache_helper import CacheKeyValue
+
+        assert isinstance(
+            obj, transformers.cache_utils.StaticCache
+        ), f"Unexpected type {type(obj)}"
+        obj = CacheKeyValue(obj)
+        atts = ["key_cache", "value_cache"]
+        for i, att in enumerate(atts):
+            if i == len(atts) - 1:
+                for p, o in _flatten_iterator(getattr(obj, att), sep):
+                    yield f"StaticCache._{att}{sep}{p}", o
+            else:
+                for p, o in _flatten_iterator(getattr(obj, att), sep):
+                    yield f"StaticCache_{att}{sep}{p}", o
     else:
         raise NotImplementedError(f"Unexpected type {type(obj)}")

onnx_diagnostic/helpers/model_builder_helper.py
@@ -203,6 +203,7 @@ def create_model_builder(
         "ChatGLMModel": builder.ChatGLMModel,
         "Ernie4_5_ForCausalLM": builder.ErnieModel,
         "GemmaForCausalLM": builder.Gemma2Model,
+        "Gemma2ForCausalLM": builder.Gemma2Model,
         "Gemma3ForCausalLM": builder.Gemma3Model,
         "Gemma3ForConditionalGeneration": builder.Gemma3Model,
         "GraniteForCausalLM": builder.GraniteModel,
onnx_diagnostic/helpers/rt_helper.py
@@ -3,7 +3,7 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .onnx_helper import dtype_to_tensor_dtype
+from .torch_helper import to_numpy
 from .cache_helper import is_cache_dynamic_registered


@@ -23,6 +23,7 @@ def make_feeds(
     use_numpy: bool = False,
     copy: bool = False,
     check_flatten: bool = True,
+    is_modelbuilder: bool = False,
 ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
     """
     Serializes the inputs to produce feeds expected
@@ -35,10 +36,15 @@ def make_feeds(
         by ``OrtValue``
     :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
         returns the same number of outputs
+    :param is_modelbuilder: if True, the exporter is ModelBuilder, and we need to reorder
+        the past_key_values inputs to match the expected order, and get rid of position_ids
+    :return: feeds dictionary
     """
-    # position_ids is a special case because ModelBuilder does not usually use it.
-    # We use types to detect the best inputs.
+    # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
+    # because it's fused into the rotary embedding in GQA.
+    if is_modelbuilder and isinstance(inputs, dict):
+        inputs.pop("position_ids", None)  # No-op if 'position_ids' is already absent.
+
     flat = flatten_object(inputs, drop_keys=True)
     assert (
         not check_flatten
@@ -51,7 +57,7 @@ def make_feeds(
         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
-        flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]
+        flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]
         if isinstance(proto, onnx.ModelProto)
@@ -76,39 +82,6 @@ def make_feeds(
         f"\n-- inputs={string_type(inputs, with_shape=True)}"
         f"\n-- names={names}"
     )
-    if len(names) < len(flat) and (
-        isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
-    ):
-
-        typed_names = (
-            [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
-            if isinstance(proto, onnx.ModelProto)
-            else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
-        )
-
-        new_flat = []
-        pos = 0
-        for _name, dtype in typed_names:
-            assert isinstance(
-                dtype, int
-            ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
-            itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            while dtype != itype:
-                pos += 1
-                if pos >= len(flat):
-                    break
-                itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            if pos >= len(flat):
-                break
-            new_flat.append(flat[pos])
-            pos += 1
-        assert len(new_flat) == len(names), (
-            f"Unable to align expected input {names} with the given input, "
-            f"type(proto)={type(proto)}"
-            f"\n-- inputs: {string_type(inputs, with_shape=True)}"
-            f"\n-- typed_names: {typed_names}"
-        )
-        flat = new_flat

     if copy:
         flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
@@ -122,4 +95,49 @@ def make_feeds(
         elif isinstance(i, float):
             i = np.array(i, dtype=np.float32)
         new_flat.append(i)
+
+    # NOTE: model builder has a different order for past_key_values,
+    # we need to reorder them to match the expected order.
+    if is_modelbuilder:
+        # We assume "past_key_values" appears in the names when the
+        # exporter is ModelBuilder.
+        non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
+        past_kv_names = [n for n in names if "past_key_values" in n]
+        reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
+        names = non_past_kv_input_names + reorder_past_kv_names
     return dict(zip(names, new_flat))
+
+
+def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
+    """
+    Reorders the past_kvs for ModelBuilder to match the expected order
+    by PyTorch exported models.
+
+    .. note::
+        This function can take either the names or the actual tensors
+        as long as they are in a list.
+
+    Conceptually,
+
+    From::
+
+        [past_key_values.0.key, past_key_values.0.value,
+         past_key_values.1.key, past_key_values.1.value, ...]
+
+    To::
+
+        [past_key_values.0.key, past_key_values.1.key,
+         ..., past_key_values.0.value, past_key_values.1.value, ...]
+
+    :param past_kv: list of flattened inputs
+    :return: reordered list of flattened inputs
+    """
+    total_len = len(past_kv)
+    if total_len % 2 != 0:
+        raise ValueError("The length of past_key_values should be even.")
+    keys = []
+    values = []
+    for i in range(0, total_len, 2):
+        keys.append(past_kv[i])
+        values.append(past_kv[i + 1])
+    return keys + values
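For instance, with the function added above::

    names = [
        "past_key_values.0.key", "past_key_values.0.value",
        "past_key_values.1.key", "past_key_values.1.value",
    ]
    print(reorder_modelbuilder_cache_to_torch(names))
    # ['past_key_values.0.key', 'past_key_values.1.key',
    #  'past_key_values.0.value', 'past_key_values.1.value']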
onnx_diagnostic/helpers/torch_helper.py
@@ -5,7 +5,7 @@ import os
 import sys
 import warnings
 from collections.abc import Iterable
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
 import onnx
 from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
@@ -283,9 +283,11 @@ def steal_forward(
     ],
     fprint: Callable = string_type,
     dump_file: Optional[str] = None,
+    dump_drop: Optional[Set[str]] = None,
     submodules: bool = False,
     verbose: int = 0,
     storage_limit: int = 2**27,
+    save_as_external_data: bool = True,
     **kwargs,
 ):
     """
@@ -303,6 +305,9 @@ def steal_forward(
     :param dump_file: dumps stolen inputs and outputs in an onnx model,
         they can be restored with :func:`create_input_tensors_from_onnx_model
         <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
+    :param dump_drop: drops inputs which are too big (only if dump_file is specified)
+    :param save_as_external_data: True by default, but it may be better to have everything
+        in a single file if possible
     :param submodules: if True and model is a module, the list extended with all the submodules
         the module contains
     :param verbose: verbosity
@@ -411,6 +416,15 @@ def steal_forward(
         if verbose:
             size = torch_tensor_size(storage)
             print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
+        if dump_drop:
+            for k, v in storage.items():
+                if k[-1] == "I":
+                    _args, kwargs = v
+                    ii = set(kwargs) & dump_drop
+                    if ii:
+                        for i in ii:
+                            print("---", i)
+                            del kwargs[i]
         proto = create_onnx_model_from_input_tensors(storage)
         if verbose:
             print("-- dumps stored objects")
@@ -420,7 +434,7 @@ def steal_forward(
         onnx.save(
             proto,
             dump_file,
-            save_as_external_data=True,
+            save_as_external_data=save_as_external_data,
             all_tensors_to_one_file=True,
             location=location,
         )
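A hypothetical use of the two new arguments (the model and input names are made up; the first argument follows steal_forward's existing signature)::

    with steal_forward(
        model,
        dump_file="dump.onnx",
        dump_drop={"pixel_values"},    # drop an oversized input before dumping
        save_as_external_data=False,   # keep everything in a single file
    ):
        model(**inputs)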
@@ -464,10 +478,10 @@ def is_torchdynamo_exporting() -> bool:
     return False


-def to_numpy(tensor: "torch.Tensor"):  # noqa: F821
+def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
     try:
-        return tensor.numpy()
+        return tensor.detach().cpu().numpy()
     except TypeError:
         # We try with ml_dtypes
         pass
@@ -476,7 +490,7 @@ def to_numpy(tensor: "torch.Tensor"):  # noqa: F821

     conv = {torch.bfloat16: ml_dtypes.bfloat16}
     assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
-    return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
+    return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
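The net effect: to_numpy now accepts tensors that require grad or live on another device, and still falls back to ml_dtypes for bfloat16. A short sketch of both paths::

    import ml_dtypes
    import torch

    t = torch.randn(3, requires_grad=True)
    a = t.detach().cpu().numpy()  # the new default path

    tb = torch.zeros(3, dtype=torch.bfloat16)
    # bfloat16 has no native numpy dtype; the fallback goes through float32
    ab = tb.detach().to(torch.float32).cpu().numpy().astype(ml_dtypes.bfloat16)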
@@ -765,7 +779,12 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:


 def torch_deepcopy(value: Any) -> Any:
-    """Makes a deepcopy."""
+    """
+    Makes a deep copy.
+
+    :param value: any value
+    :return: a deep copy
+    """
     if value is None:
         return None
     if isinstance(value, (int, float, str)):
@@ -794,9 +813,14 @@ def torch_deepcopy(value: Any) -> Any:
         from .cache_helper import CacheKeyValue

         ca = CacheKeyValue(value)
+        if len(ca.key_cache) == 0:
+            # Falls back to copy.deepcopy for an empty cache.
+            import copy
+
+            return copy.deepcopy(value)
         return make_static_cache(
             torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))),
-            max_cache_len=value.max_cache_len,
+            max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
         )
     if value.__class__.__name__ == "HybridCache":
         from .cache_helper import CacheKeyValue
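The new max_cache_len guards against rebuilding a StaticCache shorter than the sequences it already holds. With hypothetical shapes::

    declared_max_cache_len = 16
    key_shapes = [(1, 8, 32, 64)]  # made-up (batch, heads, seq_length, head_dim)
    new_len = max([declared_max_cache_len, *[s[2] for s in key_shapes]])
    print(new_len)  # 32: the copy keeps room for the longer stored keys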
onnx_diagnostic/reference/torch_evaluator.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import onnx
 import torch
-from ..helpers.torch_helper import to_tensor
+from ..helpers.torch_helper import to_tensor, to_numpy
 from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
 from .report_results_comparison import ReportResultComparison
 from . import torch_ops
@@ -578,7 +578,7 @@ class TorchOnnxEvaluator:
             print(f"- clean {o}")

         if use_numpy:
-            return [None if a is None else a.detach().cpu().numpy() for a in fres]
+            return [None if a is None else to_numpy(a) for a in fres]
         return fres

     def run_with_values(
onnx_diagnostic/tasks/data/__init__.py (new file)
@@ -0,0 +1,13 @@
+import os
+
+
+def get_data(name: str):
+    """Returns data stored in this folder."""
+    filename = os.path.join(os.path.dirname(__file__), name)
+    assert os.path.exists(
+        filename
+    ), f"Unable to find a file with {name!r}, looked for {filename!r}"
+
+    from ...helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
+
+    return create_input_tensors_from_onnx_model(filename)
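A likely use of this new helper, loading the dummy inputs shipped as file 12 of this release::

    from onnx_diagnostic.tasks.data import get_data

    data = get_data("dummies_imagetext2text_generation_gemma3.onnx")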