onnx-diagnostic 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -77
  3. onnx_diagnostic/doc.py +68 -0
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +59 -0
  6. onnx_diagnostic/helpers/config_helper.py +8 -4
  7. onnx_diagnostic/helpers/doc_helper.py +27 -7
  8. onnx_diagnostic/helpers/helper.py +30 -3
  9. onnx_diagnostic/helpers/log_helper.py +585 -0
  10. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  11. onnx_diagnostic/helpers/model_builder_helper.py +57 -73
  12. onnx_diagnostic/helpers/onnx_helper.py +291 -7
  13. onnx_diagnostic/helpers/torch_helper.py +18 -2
  14. onnx_diagnostic/reference/__init__.py +1 -0
  15. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  16. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  17. onnx_diagnostic/reference/torch_evaluator.py +23 -2
  18. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  19. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  20. onnx_diagnostic/tasks/fill_mask.py +3 -0
  21. onnx_diagnostic/tasks/image_classification.py +7 -1
  22. onnx_diagnostic/tasks/image_text_to_text.py +3 -0
  23. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  24. onnx_diagnostic/tasks/object_detection.py +3 -0
  25. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  26. onnx_diagnostic/tasks/summarization.py +3 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  28. onnx_diagnostic/tasks/text_classification.py +3 -0
  29. onnx_diagnostic/tasks/text_generation.py +90 -43
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  33. onnx_diagnostic/torch_export_patches/patch_module_helper.py +1 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  35. onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
  38. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  39. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +174 -114
  40. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
  41. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +44 -42
  42. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
  43. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  44. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
@@ -3,9 +3,9 @@ import os
3
3
  import requests
4
4
  import sys
5
5
  from pathlib import Path
6
- from typing import Any, Optional
6
+ from typing import Any, Optional, Union
7
7
  from urllib.parse import urlparse
8
- from onnx import helper, save_model, external_data_helper, ModelProto
8
+ from onnx import ModelProto, TensorProto
9
9
 
10
10
  CACHE_SUBDIR = "onnx-diagnostic"
11
11
 
@@ -114,87 +114,58 @@ def _make_model(self, model, verbose: int = 0):
114
114
  self.make_lm_head(module)
115
115
 
116
116
 
117
- def save_model_builder(self, out_dir: Optional[str] = "", verbose: int = 0) -> ModelProto:
117
+ def save_model_builder(
118
+ self, out_dir: Optional[str] = "", verbose: int = 0
119
+ ) -> Union[str, ModelProto]:
118
120
  """
119
121
  Saves a model created by function :func:`create_model_builder`.
120
122
  If out_dir is empty or not specified, the function still returns the
121
123
  generated model.
122
124
  """
123
- if verbose:
124
- print(f"[save_model_builder] Saving ONNX model in {out_dir}")
125
-
126
- # Create ONNX model
127
- model = helper.make_model(
128
- opset_imports=[
129
- self.clear_field(
130
- helper.make_operatorsetid("", 21 if self.quant_attrs["use_qdq"] else 14),
131
- "domain",
132
- ),
133
- helper.make_operatorsetid("com.microsoft", 1),
134
- ],
135
- ir_version=7,
136
- producer_name="onnxruntime-genai",
137
- producer_version="0.0.0",
138
- graph=self.make_graph(
139
- name="main_graph",
140
- inputs=self.inputs,
141
- outputs=self.outputs,
142
- initializer=self.initializers,
143
- value_info=self.value_infos,
144
- nodes=self.nodes,
145
- ),
146
- )
147
-
148
- # Load external data into ONNX model
149
- external_data_helper.load_external_data_for_model(model, self.cache_dir)
150
-
151
- # Delete external data files on disk before re-saving
152
- for path in os.listdir(self.cache_dir):
153
- if path.endswith(".bin"):
154
- os.remove(os.path.join(self.cache_dir, path))
125
+ import onnx_ir
155
126
 
156
- # Delete temporary cache dir if empty
157
- # if len(os.listdir(self.cache_dir)) == 0:
158
- # os.rmdir(self.cache_dir)
127
+ if verbose:
128
+ print(f"[save_model_builder] Saving ONNX model in {out_dir!r}")
159
129
 
160
- # Quantize ONNX model to desired precision
130
+ # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
161
131
  already_quantized_in_qdq_format = (
162
132
  self.quant_type is not None and self.quant_attrs["use_qdq"]
163
- ) # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
164
- if self.onnx_dtype == "int4" and not already_quantized_in_qdq_format:
165
- model = self.to_int4(model)
133
+ )
134
+ model = (
135
+ self.to_int4()
136
+ if self.onnx_dtype in {onnx_ir.DataType.INT4, onnx_ir.DataType.UINT4}
137
+ and not already_quantized_in_qdq_format
138
+ else self.model
139
+ )
140
+ model.graph.sort()
141
+ if not out_dir:
142
+ return onnx_ir.to_proto(model)
166
143
 
167
- # Save ONNX model with only one external data file and delete any existing duplicate copies
168
- if out_dir:
169
- out_path = os.path.join(out_dir, self.filename)
170
- data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
171
- if os.path.exists(out_path):
172
- if verbose:
173
- print(f"[save_model_builder] Overwriting {out_path!r}")
174
- os.remove(out_path)
175
- if os.path.exists(data_path):
176
- if verbose:
177
- print(f"[save_model_builder] Overwriting {data_path!r}")
178
- os.remove(data_path)
144
+ out_path = os.path.join(out_dir, self.filename)
145
+ data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
179
146
 
180
- if out_dir:
181
- location = os.path.basename(data_path)
182
- if os.path.exists(location):
183
- os.remove(location)
147
+ # Save ONNX model with only one external data file and delete any existing duplicate copies
148
+ out_path = os.path.join(out_dir, self.filename)
149
+ data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
150
+ if os.path.exists(out_path):
184
151
  if verbose:
185
- print(f"[save_model_builder] out_path={out_path!r}")
186
- print(f"[save_model_builder] location={location!r}")
187
- save_model(
188
- model,
189
- out_path,
190
- save_as_external_data=True,
191
- all_tensors_to_one_file=True,
192
- location=location,
193
- size_threshold=1024,
194
- convert_attribute=False,
195
- )
196
- return None
197
- return model
152
+ print(f"[save_model_builder] Overwriting {out_path!r}")
153
+ os.remove(out_path)
154
+ if os.path.exists(data_path):
155
+ if verbose:
156
+ print(f"[save_model_builder] Overwriting {data_path!r}")
157
+ os.remove(data_path)
158
+
159
+ onnx_ir.save(
160
+ model,
161
+ out_path,
162
+ external_data=os.path.basename(data_path),
163
+ size_threshold_bytes=2**10,
164
+ )
165
+ if verbose:
166
+ print(f"[save_model_builder] saved in {out_dir!r}")
167
+
168
+ return out_path
198
169
 
199
170
 
200
171
  def create_model_builder(
@@ -220,6 +191,9 @@ def create_model_builder(
220
191
  """
221
192
  assert cache_dir, "create_model_builder does not work without cache_dir."
222
193
  assert os.path.exists(cache_dir), f"cache_dir={cache_dir!r} does not exists"
194
+ precision = {"float32": "fp32", "float16": "fp16", "bfloat16": "bfp16"}.get(
195
+ precision, precision
196
+ )
223
197
  download_model_builder_to_cache()
224
198
  builder = import_model_builder()
225
199
  io_dtype = builder.set_io_dtype(precision, execution_provider, extra_options)
@@ -332,13 +306,23 @@ def create_model_builder(
332
306
  for c in remove:
333
307
  delattr(config, c)
334
308
 
335
- onnx_model = cls(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
309
+ convert = {
310
+ "fp32": TensorProto.FLOAT,
311
+ "fp16": TensorProto.FLOAT16,
312
+ "bfp16": TensorProto.BFLOAT16,
313
+ }
314
+ assert (
315
+ precision in convert
316
+ ), f"Unexpected value for precision={precision!r}, should be in {convert}"
317
+ onnx_model = cls(
318
+ config, io_dtype, convert[precision], execution_provider, cache_dir, extra_options
319
+ )
336
320
 
337
321
  if post:
338
322
  post(onnx_model)
339
323
  _make_model(onnx_model, model, verbose=verbose)
340
324
 
341
- assert onnx_model.nodes, (
325
+ assert onnx_model.model, (
342
326
  f"No node in the model, io_dtype={io_dtype!r}, "
343
327
  f"precision={precision!r}, execution_provider={execution_provider!r}, "
344
328
  f"extra_options={extra_options!r}, cache_dir={cache_dir!r}, "
@@ -316,7 +316,7 @@ def check_model_ort(
316
316
 
317
317
 
318
318
  @functools.cache
319
- def onnx_dtype_name(itype: int) -> str:
319
+ def onnx_dtype_name(itype: int, exc: bool = True) -> str:
320
320
  """
321
321
  Returns the ONNX name for a specific element type.
322
322
 
@@ -335,7 +335,11 @@ def onnx_dtype_name(itype: int) -> str:
335
335
  v = getattr(TensorProto, k)
336
336
  if v == itype:
337
337
  return k
338
- raise ValueError(f"Unexpected value itype: {itype}")
338
+ if exc:
339
+ raise ValueError(f"Unexpected value itype: {itype}")
340
+ if itype == 0:
341
+ return "UNDEFINED"
342
+ return "UNEXPECTED"
339
343
 
340
344
 
341
345
  def pretty_onnx(
@@ -365,7 +369,7 @@ def pretty_onnx(
365
369
  itype = onx.type.tensor_type.elem_type
366
370
  shape = tuple((d.dim_param or d.dim_value) for d in onx.type.tensor_type.shape.dim)
367
371
  shape_str = ",".join(map(str, shape))
368
- return f"{onnx_dtype_name(itype)}[{shape_str}] {name}"
372
+ return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"
369
373
 
370
374
  if isinstance(onx, AttributeProto):
371
375
  att = onx
@@ -767,7 +771,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
767
771
 
768
772
 
769
773
  def iterator_initializer_constant(
770
- model: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
774
+ model: Union[FunctionProto, GraphProto, ModelProto],
771
775
  use_numpy: bool = True,
772
776
  prefix: str = "",
773
777
  ) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]: # noqa: F821
@@ -779,8 +783,8 @@ def iterator_initializer_constant(
779
783
  :param prefix: for subgraph
780
784
  :return: iterator
781
785
  """
782
- if not isinstance(model, onnx.FunctionProto):
783
- graph = model if isinstance(model, onnx.GraphProto) else model.graph
786
+ if not isinstance(model, FunctionProto):
787
+ graph = model if isinstance(model, GraphProto) else model.graph
784
788
  if not use_numpy:
785
789
  from .torch_helper import to_tensor
786
790
  if prefix:
@@ -791,7 +795,7 @@ def iterator_initializer_constant(
791
795
  )
792
796
  nodes = graph.node
793
797
  name = graph.name
794
- if isinstance(model, onnx.ModelProto):
798
+ if isinstance(model, ModelProto):
795
799
  for f in model.functions:
796
800
  yield from iterator_initializer_constant(
797
801
  f, use_numpy=use_numpy, prefix=f"{prefix}{f.name}"
@@ -908,3 +912,283 @@ def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union
908
912
  qu = np.quantile(tensor, ii)
909
913
  stat.update({f"q{i}": float(q) for i, q in zip(ii, qu)})
910
914
  return stat
915
+
916
+
917
+ class NodeCoordinates:
918
+ """
919
+ A way to localize a node,
920
+ path is a tuple of three information, node index, node type, node name.
921
+ """
922
+
923
+ __slots__ = ("node", "path")
924
+
925
+ def __init__(
926
+ self,
927
+ node: Union[onnx.TensorProto, NodeProto, str],
928
+ path: Tuple[Tuple[int, str, str], ...],
929
+ ):
930
+ assert isinstance(path, tuple), f"Unexpected type {type(path)} for path"
931
+ assert all(isinstance(t, tuple) for t in path), f"Unexpected type in path={path}"
932
+ self.node = node
933
+ self.path = path
934
+
935
+ def __str__(self) -> str:
936
+ "usual"
937
+ if isinstance(self.node, str):
938
+ return f"{self.path_to_str()} :: {self.node!r}"
939
+ return f"{self.path_to_str()} :: {pretty_onnx(self.node)}"
940
+
941
+ def path_to_str(self) -> str:
942
+ "Strings representing coordinates."
943
+ return "x".join(f"({':'.join(map(str, t))})" for t in self.path)
944
+
945
+
946
+ class ResultFound:
947
+ """
948
+ Class returned by :func:`enumerate_results`.
949
+ """
950
+
951
+ __slots__ = ("consumer", "name", "producer")
952
+
953
+ def __init__(
954
+ self,
955
+ name: str,
956
+ producer: Optional[NodeCoordinates],
957
+ consumer: Optional[NodeCoordinates],
958
+ ):
959
+ assert isinstance(name, str), f"unexpected type {type(name)} for name"
960
+ self.name = name
961
+ self.producer = producer
962
+ self.consumer = consumer
963
+
964
+ def __str__(self) -> str:
965
+ "usuals"
966
+ return (
967
+ f"<< {self.name} - {self.consumer}"
968
+ if self.producer is None
969
+ else f">> {self.name} - {self.producer}"
970
+ )
971
+
972
+
973
+ def enumerate_results(
974
+ proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
975
+ name: Union[Set[str], str],
976
+ verbose: int = 0,
977
+ coordinates: Optional[List[Tuple[int, str, str]]] = None,
978
+ ) -> Iterator[ResultFound]:
979
+ """
980
+ Iterates on all nodes, attributes to find where a name is used.
981
+
982
+ :param proto: a proto
983
+ :param name: name or names to find
984
+ :param verbose: verbosity
985
+ :param coordinates: coordinates of a node
986
+ :return: iterator on :class:`ResultFound`
987
+ """
988
+ if not isinstance(name, set):
989
+ name = {name}
990
+ coordinates = coordinates or []
991
+ assert all(
992
+ isinstance(c, tuple) for c in coordinates
993
+ ), f"Unexpected type in coordinates={coordinates}"
994
+ indent = " " * len(coordinates)
995
+ if isinstance(proto, ModelProto):
996
+ if verbose:
997
+ print(f"[enumerate_results] {indent}searching for {name!r} into ModelProto...")
998
+ yield from enumerate_results(proto.graph, name, verbose=verbose)
999
+ elif isinstance(proto, FunctionProto):
1000
+ if verbose:
1001
+ print(f"[enumerate_results] {indent}searching for {name!r} into FunctionProto...")
1002
+ for i in proto.input:
1003
+ if i in name:
1004
+ r = ResultFound(
1005
+ i,
1006
+ NodeCoordinates(i, tuple([*coordinates, (-1, "INPUT", "")])), # noqa: C409
1007
+ None,
1008
+ )
1009
+ if verbose > 1:
1010
+ print(f"[enumerate_results] {indent}-- {r}")
1011
+ yield r
1012
+ yield from enumerate_results(proto.node, name, verbose=verbose)
1013
+ for i in proto.output:
1014
+ if i in name:
1015
+ r = ResultFound(
1016
+ i,
1017
+ None,
1018
+ NodeCoordinates(
1019
+ i, tuple([*coordinates, (len(proto.node), "OUTPUT", "")]) # noqa: C409
1020
+ ),
1021
+ )
1022
+ if verbose > 1:
1023
+ print(f"[enumerate_results] {indent}-- {r}")
1024
+ yield r
1025
+ elif isinstance(proto, GraphProto):
1026
+ if verbose:
1027
+ print(f"[enumerate_results] {indent}searching for {name!r} into GraphProto...")
1028
+ for i in proto.initializer:
1029
+ if i.name in name:
1030
+ r = ResultFound(
1031
+ i.name,
1032
+ NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])), # noqa: C409
1033
+ None,
1034
+ )
1035
+ if verbose > 1:
1036
+ print(f"[enumerate_results] {indent}-- {r}")
1037
+ yield r
1038
+ for i in proto.sparse_initializer:
1039
+ if i.name in name:
1040
+ r = ResultFound(
1041
+ i.name,
1042
+ NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])), # noqa: C409
1043
+ None,
1044
+ )
1045
+ if verbose > 1:
1046
+ print(f"[enumerate_results] {indent}-- {r}")
1047
+ yield r
1048
+ for i in proto.input:
1049
+ if i.name in name:
1050
+ r = ResultFound(
1051
+ i.name,
1052
+ NodeCoordinates(i, tuple([*coordinates, (-1, "INPUT", "")])), # noqa: C409
1053
+ None,
1054
+ )
1055
+ if verbose > 1:
1056
+ print(f"[enumerate_results] {indent}-- {r}")
1057
+ yield r
1058
+ yield from enumerate_results(
1059
+ proto.node, name, verbose=verbose, coordinates=coordinates
1060
+ )
1061
+ for i in proto.output:
1062
+ if i.name in name:
1063
+ r = ResultFound(
1064
+ i.name,
1065
+ None,
1066
+ NodeCoordinates(
1067
+ i, tuple([*coordinates, (len(proto.node), "OUTPUT", "")]) # noqa: C409
1068
+ ),
1069
+ )
1070
+ if verbose > 1:
1071
+ print(f"[enumerate_results] {indent}-- {r}")
1072
+ yield r
1073
+ else:
1074
+ if verbose:
1075
+ print(
1076
+ f"[enumerate_results] {indent}searching for {name!r} into List[NodeProto]..."
1077
+ )
1078
+ for node_i, node in enumerate(proto):
1079
+ if set(node.input) & name:
1080
+ for n in node.input:
1081
+ if n in name:
1082
+ r = ResultFound(
1083
+ n,
1084
+ NodeCoordinates(
1085
+ node,
1086
+ tuple( # noqa: C409
1087
+ [*coordinates, (node_i, node.op_type, node.name)]
1088
+ ),
1089
+ ),
1090
+ None,
1091
+ )
1092
+ if verbose > 1:
1093
+ print(f"[enumerate_results] {indent}-- {r}")
1094
+ yield r
1095
+ if node.op_type in {"If", "Scan", "Loop", "SequenceMap"}:
1096
+ for att in node.attribute:
1097
+ if att.type == onnx.AttributeProto.GRAPH:
1098
+ yield from enumerate_results(
1099
+ att.g,
1100
+ name,
1101
+ verbose=verbose,
1102
+ coordinates=[*coordinates, (node_i, node.op_type, node.name)],
1103
+ )
1104
+ if set(node.output) & name:
1105
+ for n in node.output:
1106
+ if n in name:
1107
+ r = ResultFound(
1108
+ n,
1109
+ None,
1110
+ NodeCoordinates(
1111
+ node,
1112
+ tuple( # noqa: C409
1113
+ [*coordinates, (node_i, node.op_type, node.name)]
1114
+ ),
1115
+ ),
1116
+ )
1117
+ if verbose > 1:
1118
+ print(f"[enumerate_results] {indent}-- {r}")
1119
+ yield r
1120
+ if verbose:
1121
+ print(f"[enumerate_results] {indent}done")
1122
+
1123
+
1124
+ def shadowing_names(
1125
+ proto: Union[FunctionProto, GraphProto, ModelProto, Sequence[NodeProto]],
1126
+ verbose: int = 0,
1127
+ existing: Optional[Set[str]] = None,
1128
+ shadow_context: Optional[Set[str]] = None,
1129
+ post_shadow_context: Optional[Set[str]] = None,
1130
+ ) -> Tuple[Set[str], Set[str], Set[str]]:
1131
+ """
1132
+ Returns the shadowing names, the names created in the main graph
1133
+ after they were created in a subgraphs and the names created by the nodes.
1134
+ """
1135
+ if isinstance(proto, ModelProto):
1136
+ return shadowing_names(proto.graph)
1137
+ if isinstance(proto, GraphProto):
1138
+ assert (
1139
+ existing is None and shadow_context is None
1140
+ ), "existing must be None if nodes is None"
1141
+ return shadowing_names(
1142
+ proto.node,
1143
+ verbose=verbose,
1144
+ existing=set(i.name for i in proto.initializer)
1145
+ | set(i.name for i in proto.sparse_initializer)
1146
+ | set(i.name for i in proto.input if i.name),
1147
+ shadow_context=set(),
1148
+ post_shadow_context=set(),
1149
+ )
1150
+ if isinstance(proto, FunctionProto):
1151
+ assert (
1152
+ existing is None and shadow_context is None
1153
+ ), "existing must be None if nodes is None"
1154
+ return shadowing_names(
1155
+ proto.node,
1156
+ verbose=verbose,
1157
+ existing=set(i for i in proto.input if i),
1158
+ shadow_context=set(),
1159
+ post_shadow_context=set(),
1160
+ )
1161
+
1162
+ assert (
1163
+ existing is not None and shadow_context is not None
1164
+ ), "existing must not be None if nodes is not None"
1165
+ shadow = set()
1166
+ shadow_context = shadow_context.copy()
1167
+ existing = existing.copy()
1168
+ created = set()
1169
+ post_shadow = set()
1170
+ for node in proto:
1171
+ not_empty = set(n for n in node.input if n)
1172
+ intersection = not_empty & existing
1173
+ assert len(intersection) == len(not_empty), (
1174
+ f"One input in {not_empty}, node={pretty_onnx(node)} "
1175
+ f"was not found in {existing}"
1176
+ )
1177
+ for att in node.attribute:
1178
+ if att.type == AttributeProto.GRAPH:
1179
+ g = att.g
1180
+ shadow |= set(i.name for i in g.input) & shadow_context
1181
+ shadow |= set(i.name for i in g.initializer) & shadow_context
1182
+ shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
1183
+ s, ps, c = shadowing_names(
1184
+ g.node, verbose=verbose, existing=existing, shadow_context=existing
1185
+ )
1186
+ shadow |= s
1187
+ created |= c
1188
+
1189
+ not_empty = set(n for n in node.output if n)
1190
+ post_shadow |= not_empty & created
1191
+ shadow |= not_empty & shadow_context
1192
+ existing |= not_empty
1193
+ created |= not_empty
1194
+ return shadow, post_shadow, created
@@ -16,6 +16,7 @@ from .cache_helper import (
16
16
  make_encoder_decoder_cache,
17
17
  make_sliding_window_cache,
18
18
  make_mamba_cache,
19
+ make_static_cache,
19
20
  )
20
21
  from .mini_onnx_builder import create_onnx_model_from_input_tensors
21
22
  from .onnx_helper import (
@@ -288,7 +289,8 @@ def steal_forward(
288
289
  """
289
290
  The necessary modification to steem forward method and prints out inputs
290
291
  and outputs using :func:`onnx_diagnostic.helpers.string_type`.
291
- See example :ref:`l-plot-tiny-llm-export`.
292
+ See example :ref:`l-plot-tiny-llm-export` or
293
+ :ref:`l-plot-intermediate-results`.
292
294
 
293
295
  :param model: a model or a list of models to monitor,
294
296
  every model can also be a tuple(name, model), name is displayed well.
@@ -410,12 +412,15 @@ def steal_forward(
410
412
  proto = create_onnx_model_from_input_tensors(storage)
411
413
  if verbose:
412
414
  print("-- dumps stored objects")
415
+ location = f"{os.path.split(dump_file)[-1]}.data"
416
+ if os.path.exists(location):
417
+ os.remove(location)
413
418
  onnx.save(
414
419
  proto,
415
420
  dump_file,
416
421
  save_as_external_data=True,
417
422
  all_tensors_to_one_file=True,
418
- location=f"{os.path.split(dump_file)[-1]}.data",
423
+ location=location,
419
424
  )
420
425
  if verbose:
421
426
  print("-- done dump stored objects")
@@ -723,6 +728,15 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
723
728
  )
724
729
  )
725
730
  )
731
+ if value.__class__.__name__ == "StaticCache":
732
+ return make_static_cache(
733
+ list(
734
+ zip(
735
+ [t.to(to_value) for t in value.key_cache],
736
+ [t.to(to_value) for t in value.value_cache],
737
+ )
738
+ )
739
+ )
726
740
  if value.__class__.__name__ == "EncoderDecoderCache":
727
741
  return make_encoder_decoder_cache(
728
742
  to_any(value.self_attention_cache, to_value),
@@ -769,6 +783,8 @@ def torch_deepcopy(value: Any) -> Any:
769
783
  return make_dynamic_cache(
770
784
  torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
771
785
  )
786
+ if value.__class__.__name__ == "StaticCache":
787
+ return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
772
788
  if value.__class__.__name__ == "SlidingWindowCache":
773
789
  return make_sliding_window_cache(
774
790
  torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
@@ -1,3 +1,4 @@
1
1
  from .evaluator import ExtendedReferenceEvaluator
2
2
  from .ort_evaluator import OnnxruntimeEvaluator
3
3
  from .torch_evaluator import TorchOnnxEvaluator
4
+ from .report_results_comparison import ReportResultComparison