onnx-diagnostic 0.8.11__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
  3. onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
  4. onnx_diagnostic/ci_models/data/__init__.py +0 -0
  5. onnx_diagnostic/ci_models/export_phi4_mm.py +8 -3
  6. onnx_diagnostic/export/api.py +11 -0
  7. onnx_diagnostic/export/dynamic_shapes.py +1 -1
  8. onnx_diagnostic/helpers/cache_helper.py +96 -30
  9. onnx_diagnostic/helpers/helper.py +39 -0
  10. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  11. onnx_diagnostic/helpers/ort_session.py +5 -1
  12. onnx_diagnostic/helpers/rt_helper.py +53 -9
  13. onnx_diagnostic/helpers/torch_helper.py +7 -2
  14. onnx_diagnostic/investigate/input_observer.py +793 -152
  15. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
  16. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
  17. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
  18. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
  19. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +28 -2
  20. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
  21. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +24 -21
  22. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
  23. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  24. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,21 @@
1
1
  import contextlib
2
2
  import inspect
3
+ import time
3
4
  from typing import Any, Callable, Sequence
5
+ import onnx
4
6
  import torch
7
+ from ..helpers import max_diff, string_type
8
+ from ..reference import OnnxruntimeEvaluator
5
9
 
10
+ EOL = "\n"
6
11
 
7
- def flatten_unflatten_for_dynamic_shapes(
12
+
13
+ def _flatten_unflatten_for_dynamic_shapes(
8
14
  obj: Any,
9
15
  use_dict: bool = True,
10
16
  change_function: Callable[[torch.Tensor], Any] | None = None,
11
17
  ) -> Any:
12
- """
13
- Returns the object in a different structure similar to what
18
+ """Returns the object in a different structure similar to what
14
19
  the definition of the dynamic shapes should use.
15
20
 
16
21
  Args:
@@ -27,7 +32,7 @@ def flatten_unflatten_for_dynamic_shapes(
27
32
  like replace them by a shape
28
33
 
29
34
  Returns:
30
- the serialized object
35
+ the flattened object
31
36
  """
32
37
  if isinstance(obj, torch.Tensor):
33
38
  return change_function(obj) if change_function else obj
@@ -38,7 +43,7 @@ def flatten_unflatten_for_dynamic_shapes(
38
43
  for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
39
44
  end += subspec.num_leaves
40
45
  value = subspec.unflatten(flat[start:end])
41
- value = flatten_unflatten_for_dynamic_shapes(
46
+ value = _flatten_unflatten_for_dynamic_shapes(
42
47
  value, use_dict=use_dict, change_function=change_function
43
48
  )
44
49
  subtrees.append(value)
@@ -66,157 +71,435 @@ def flatten_unflatten_for_dynamic_shapes(
66
71
  return subtrees
67
72
 
68
73
 
69
- def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int]:
70
- """
71
- Returns the list of dynamic dimensions given a list of shapes
74
+ def _infer_dynamic_dimensions(
75
+ shape_list: Sequence[tuple[int, ...]], set_batch_dimension: bool = False
76
+ ) -> list[int]:
77
+ """Returns the list of dynamic dimensions given a list of shapes
72
78
  corresponding to the same tensor.
73
79
 
74
80
  Args:
75
81
  shape_list:
76
82
  list of shapes, they must all have the same length
83
+ set_batch_dimension:
84
+ forces the first dimension to be treated as dynamic,
85
+ even if all shapes have the same value for that dimension
77
86
 
78
87
  Returns:
79
88
  list of dynamic dimensions
80
89
  """
81
90
  unique_ranks = {len(shape) for shape in shape_list}
82
91
  torch._check(
83
- len(unique_ranks) == 1, lambda: "all shapes in shape_list must have the same rank"
92
+ len(unique_ranks) == 1,
93
+ lambda: "all shapes in shape_list must have the same rank",
84
94
  )
85
95
  rank = unique_ranks.pop()
86
96
  dynamic = []
87
97
  for i in range(rank):
88
98
  dims = [shape[i] for shape in shape_list]
89
- if len(set(dims)) > 1:
99
+ if len(set(dims)) > 1 or (i == 0 and set_batch_dimension):
90
100
  dynamic.append(i)
91
101
  return dynamic
92
102
 
93
103
 
104
+ class InputCandidate:
105
+ """Retains one set of inputs given to the forward method or any
106
+ other method the class :class:`InputObserver` is stealing from.
107
+ Any class is allowed as long as it can be flattened.
108
+
109
+ Args:
110
+ args:
111
+ Positional arguments.
112
+ kwargs:
113
+ Optional arguments.
114
+ clone:
115
+ Clones the inputs before storing them. Some tensors
116
+ may be modified inplace, the original value must be retained.
117
+ cst_kwargs:
118
+ Any optional arguments constant over multiple calls.
119
+ int, float, str, bool values must be stored here.
120
+
121
+ The constructor flattens the received arguments.
122
+ Any necessary flattening function should have been registered first.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ args: tuple[Any, ...],
128
+ kwargs: dict[str, Any],
129
+ clone: bool,
130
+ cst_kwargs: dict[str, int | str | float | bool],
131
+ ):
132
+ self.args = args
133
+ self.kwargs = kwargs
134
+ self.flat_list, self.spec = torch.utils._pytree.tree_flatten((args, kwargs))
135
+ self.n_tensors = sum(t is not None for t in self.flat_list)
136
+ self._position_to_args_kwargs: list[int | str] | None = None
137
+ self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None
138
+ self.cst_kwargs = cst_kwargs.copy()
139
+ assert "use_cache" not in self.cst_kwargs
140
+
141
+ if clone:
142
+ self.flat_list = [
143
+ (None if not isinstance(t, torch.Tensor) else t.clone().detach())
144
+ for t in self.flat_list
145
+ ]
146
+ self.args, self.kwargs = torch.utils._pytree.tree_unflatten(
147
+ self.flat_list, self.spec
148
+ )
149
+
150
+ self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None
151
+ self.aligned_flat_list: list[torch.Tensor | None] | None = None
152
+
153
+ def __str__(self) -> str:
154
+ return (
155
+ f"{self.__class__.__name__}({len(self.args)} args, "
156
+ f"{len(self.kwargs)} kwargs, {len(self.flat_list)} tensors, "
157
+ f"{len(self.aligned_flat_list or [])} aligned tensors)"
158
+ )
159
+
160
+ def __len__(self) -> int:
161
+ """Returns the number of flattened tensors, None tensors are included."""
162
+ return len(self.flat_list)
163
+
164
+ def str_obs(self) -> str:
165
+ """Prints out some information about the observations."""
166
+ return (
167
+ f"InputCandidate(args={string_type(self.args, with_shape=True)}, "
168
+ f"kwargs={string_type(self.kwargs, with_shape=True)}, "
169
+ f"cst_kwargs={self.cst_kwargs})"
170
+ )
171
+
172
+ def build_mappings(self) -> list[int | str]:
173
+ if self._position_to_args_kwargs is not None:
174
+ return self._position_to_args_kwargs
175
+ self._n_tensors_for_args_kwargs = {}
176
+
177
+ flat_index_to_args: list[int | str] = []
178
+ for index_args, a in enumerate(self.args):
179
+ size = len(torch.utils._pytree.tree_flatten(a)[0])
180
+ self._n_tensors_for_args_kwargs[index_args] = size
181
+ flat_index_to_args.extend([index_args] * size)
182
+ for k, v in self.kwargs.items():
183
+ size = len(torch.utils._pytree.tree_flatten(v)[0])
184
+ self._n_tensors_for_args_kwargs[k] = size
185
+ flat_index_to_args.extend([k] * size)
186
+
187
+ self._position_to_args_kwargs = flat_index_to_args
188
+ return self._position_to_args_kwargs
189
+
190
+ @property
191
+ def position_to_args_kwargs(self) -> list[int | str]:
192
+ """Returns the corresponding args or kwargs
193
+ for every tensor in the flattened inputs.
194
+ """
195
+ if self._position_to_args_kwargs is None:
196
+ self.build_mappings()
197
+ # type checking is missing it
198
+ assert self._position_to_args_kwargs is not None
199
+ return self._position_to_args_kwargs
200
+
201
+ @property
202
+ def n_tensors_for_args_kwargs(self) -> dict[int | str, int]:
203
+ """Returns the number of flat tensors in every args or kwargs."""
204
+ if self._n_tensors_for_args_kwargs is None:
205
+ self.build_mappings()
206
+ # type checking is missing it
207
+ assert self._n_tensors_for_args_kwargs is not None
208
+ return self._n_tensors_for_args_kwargs
209
+
210
+ def _set_aligned_flat_list(
211
+ self,
212
+ aligned_flat_list: list[torch.Tensor | None],
213
+ aligned_spec: torch.utils._pytree.PyTreeSpec,
214
+ ):
215
+ self.aligned_flat_list = aligned_flat_list
216
+ self.aligned_spec = aligned_spec
217
+
218
+ def align_with(
219
+ self,
220
+ best_candidate: "InputCandidate",
221
+ captured_inputs: dict[int | str, int],
222
+ signature_names: list[str],
223
+ ):
224
+ """Two candidates are considered as aligned if after being flattened
225
+ they have the same number of tensors (None allowed)."""
226
+ if self.cst_kwargs != best_candidate.cst_kwargs:
227
+ raise RuntimeError(
228
+ f"Two calls were made with different constant values, "
229
+ f"{self.cst_kwargs} != {best_candidate.cst_kwargs}"
230
+ )
231
+
232
+ args = self.args
233
+ if len(self.args) > len(best_candidate.args):
234
+ # We need to move some args to kwargs as the best_candidate does.
235
+ new_kwargs = {}
236
+ for i in range(len(best_candidate.args), len(self.args)):
237
+ new_kwargs[signature_names[i]] = args[i]
238
+ args = args[: len(best_candidate.args)]
239
+ kwargs = {**new_kwargs, **self.kwargs}
240
+ else:
241
+ kwargs = self.kwargs
242
+
243
+ flat = []
244
+ for i in range(len(best_candidate.args)):
245
+ if i < len(args) and (isinstance(args[i], torch.Tensor) or args[i]):
246
+ ts = torch.utils._pytree.tree_flatten(self.args[i])[0]
247
+ if i in captured_inputs and captured_inputs[i] != len(ts):
248
+ raise RuntimeError(
249
+ f"Positional argument {i} has {len(ts)} tensors "
250
+ f"but previously got {captured_inputs[i]} tensors. "
251
+ f"Inference is impossible in that case."
252
+ )
253
+ captured_inputs[i] = len(ts)
254
+ flat.extend(ts)
255
+ continue
256
+ # If the argument i is not specified or is None or an empty container.
257
+ flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[i])])
258
+
259
+ for k in best_candidate.kwargs:
260
+ if k in kwargs and (isinstance(kwargs[k], torch.Tensor) or kwargs[k]):
261
+ ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
262
+ if k in captured_inputs and captured_inputs[k] != len(ts):
263
+ raise RuntimeError(
264
+ f"Named argument {k!r} has {len(ts)} tensors "
265
+ f"but previously got {captured_inputs[k]} tensors in "
266
+ f"kwargs={list(kwargs)}. "
267
+ f"Inference is impossible in that case."
268
+ )
269
+ captured_inputs[k] = len(ts)
270
+ flat.extend(ts)
271
+ continue
272
+ # If the argument k is not specified or is None or an empty container.
273
+ flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[k])])
274
+
275
+ self._set_aligned_flat_list(flat, best_candidate.spec)
276
+
277
+ @property
278
+ def n_aligned_tensors(self) -> int:
279
+ if self.aligned_flat_list is None:
280
+ raise RuntimeError("This input was not aligned with the others.")
281
+ return len(self.aligned_flat_list)
282
+
283
+
94
284
  class InputObserverInfo:
95
- def __init__(self, signature: inspect.Signature):
96
- # pyrefly: ignore
97
- self.inputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
98
- self.flat_inputs: list[list[torch.Tensor | None]] = []
99
- # pyrefly: ignore
100
- self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
101
- self.flat_outputs: list[torch.Tensor | list[torch.Tensor]] = []
102
- self.signature = signature
285
+ """Contains all the necessary information to infer dynamic shapes
286
+ and the arguments to send to :func:`torch.export.export`.
287
+
288
+ Args:
289
+ signature_names: Names of the arguments of the method
290
+ the collector tensors come from. They are used if it becomes
291
+ necessary to move positional arguments to named ones.
292
+ They are used a second time because :func:`torch.export.export`
293
+ cares about the order in kwargs and dynamic shapes, it needs
294
+ to be the same in the ordered dictionaries `add_inputs` receive.
295
+ default_values: Default values defined by the signature of the function,
296
+ any value equal to that is ignore to simplify the export.
297
+ missing: If a named argument (in kwargs) is missing,
298
+ a default value will be taken in this dictionary,
299
+ this is used when after the prefill step, an argument
300
+ disappears (such as `pixel_values`) and another one
301
+ is added (such as `past_key_values`).
302
+ The values are only to infer dynamic shapes and arguments,
303
+ not to run the model.
304
+ """
103
305
 
104
- self._max_args: tuple[Any, torch.Tensor] | None = None
105
- self._max_kwargs: dict[str, torch.Tensor] | None = None
306
+ def __init__(
307
+ self,
308
+ signature_names: list[str],
309
+ default_values: dict[str, int | bool | str | float],
310
+ missing: dict[str, Any],
311
+ ):
312
+ self.default_values = default_values
313
+ self.missing = missing
314
+ self.inputs: list[InputCandidate] = []
315
+ self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
316
+ self.flat_outputs: list[list[torch.Tensor | None]] = []
317
+ self.latencies: list[float] = []
318
+ self.signature_names = signature_names
319
+ self._best_candidate: InputCandidate | None = None
320
+ self._captured_inputs: dict[int | str, int] | None = None
106
321
 
107
322
  def __len__(self) -> int:
108
- return len(self.flat_inputs)
323
+ """Returns the number of collected set of inputs/outputs."""
324
+ return len(self.inputs)
109
325
 
110
326
  def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
327
+ """Stores one set of inputs. They are deepcopied.
328
+
329
+ Args:
330
+ args: Positional arguments.
331
+ kwargs: Named arguments.
332
+ """
333
+ cst_kwargs = {
334
+ k: v
335
+ for k, v in kwargs.items()
336
+ if k in self.signature_names
337
+ and isinstance(v, (int, float, bool, str))
338
+ and v != self.default_values.get(k, None)
339
+ and self.default_values.get(k, None) is not None
340
+ }
111
341
  kwargs = {
112
342
  k: v
113
343
  for k, v in kwargs.items()
114
344
  if v is not None and not isinstance(v, (int, float, bool))
115
345
  }
116
- flat_args, spec = torch.utils._pytree.tree_flatten((args, kwargs))
117
- self.inputs_specs.append(spec)
118
- cloned = [
119
- (None if not isinstance(t, torch.Tensor) else t.clone().detach())
120
- for t in flat_args
121
- ]
122
- self.flat_inputs.append(cloned)
123
346
 
124
- cloned_args, cloned_kwargs = torch.utils._pytree.tree_unflatten(cloned, spec)
125
- if self._max_args is None or len(cloned_args) > len(self._max_args):
126
- self._max_args = cloned_args
127
- if self._max_kwargs is None or len(cloned_kwargs) > len(self._max_kwargs):
128
- self._max_kwargs = cloned_kwargs
347
+ # adds missing attributes
348
+ for k, v in self.missing.items():
349
+ if k not in kwargs:
350
+ kwargs[k] = v
351
+
352
+ # kwargs may come in a different order each time.
353
+ # dictionaries are ordered and torch.export.export expects
354
+ # dynamic shapes and kwargs to follow the same order.
355
+
356
+ ordered_kwargs = {k: kwargs[k] for k in self.signature_names if k in kwargs}
357
+ for k, v in kwargs.items():
358
+ if k not in ordered_kwargs:
359
+ ordered_kwargs[k] = v
360
+
361
+ candidate = InputCandidate(args, ordered_kwargs, clone=True, cst_kwargs=cst_kwargs)
362
+ self.inputs.append(candidate)
363
+ if self._best_candidate is None or len(self._best_candidate) < len(candidate):
364
+ self._best_candidate = candidate
129
365
 
130
- def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
366
+ def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...], latency: float):
367
+ """Stores outputs. They are deepcopied."""
131
368
  flat_res, spec = torch.utils._pytree.tree_flatten(res)
132
369
  self.outputs_specs.append(spec)
133
- self.flat_outputs.append([t.clone().detach() for t in flat_res])
370
+ self.flat_outputs.append(
371
+ [(None if t is None else t.clone().detach()) for t in flat_res]
372
+ )
373
+ self.latencies.append(latency)
134
374
 
135
- def build_inputs_completed_with_none_values(self) -> list[list[torch.Tensor]]:
136
- # Let's compute the sizes of each independently.
137
- if not self.flat_inputs or self._max_args is None or self._max_kwargs is None:
375
+ def align_inputs_none_values(self):
376
+ """Once the best candidate is chosen, this method aligns every set of inputs
377
+ on the best candidate, it inserts None at the right position when
378
+ optional inputs are not specified. We consider a set of inputs is aligned
379
+ if this method does not change the original flattened inputs.
380
+ """
381
+ if not self.inputs or self._best_candidate is None:
138
382
  raise RuntimeError("No inputs were captured.")
139
- arg_sizes = [len(torch.utils._pytree.tree_flatten(a)[0]) for a in self._max_args]
140
- kwarg_sizes = {
141
- k: len(torch.utils._pytree.tree_flatten(v)[0]) for k, v in self._max_kwargs.items()
142
- }
383
+
384
+ if all(candidate.aligned_flat_list is not None for candidate in self.inputs):
385
+ # No new inputs, no alignment is necessary.
386
+ return
143
387
 
144
388
  # Let's reprocess everything.
145
- captured_inputs: dict[int | str, int] = {}
146
- new_flat_inputs = []
147
- for args_kwargs, spec in zip(self.flat_inputs, self.inputs_specs):
148
- args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
149
- if len(set(kwargs) | set(self._max_kwargs)) > len(self._max_kwargs):
389
+ self._captured_inputs = {}
390
+ for candidate in self.inputs:
391
+ if len(set(candidate.kwargs) | set(self._best_candidate.kwargs)) > len(
392
+ self._best_candidate.kwargs
393
+ ):
150
394
  raise RuntimeError(
151
- "At least one call to the observed model "
152
- "must contain all the named arguments."
395
+ f"At least one call to the observed model "
396
+ f"must contain all the named arguments. "
397
+ f"candidate kwargs={list(candidate.kwargs)}, "
398
+ f"best candidate kwargs={list(self._best_candidate.kwargs)}, "
399
+ f"all candidate kwargs={EOL}"
400
+ f"{EOL.join(string_type(c.kwargs, with_shape=True) for c in self.inputs)}"
153
401
  )
154
- flat = []
155
- for i in range(len(self._max_args)):
156
- if i < len(args):
157
- ts = torch.utils._pytree.tree_flatten(args[i])[0]
158
- if i in captured_inputs and captured_inputs[i] != len(ts):
159
- raise RuntimeError(
160
- f"Positional argument {i} has {len(ts)} tensors "
161
- f"but previously got {captured_inputs[i]} tensors. "
162
- f"Inference is impossible in that case."
163
- )
164
- captured_inputs[i] = len(ts)
165
- flat.extend(ts)
166
- else:
167
- flat.extend([None for _ in range(arg_sizes[i])])
168
- for k in self._max_kwargs:
169
- if k in kwargs:
170
- ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
171
- if k in captured_inputs and captured_inputs[k] != len(ts):
172
- raise RuntimeError(
173
- f"Named argument {k!r} has {len(ts)} tensors "
174
- f"but previously got {captured_inputs[k]} tensors. "
175
- f"Inference is impossible in that case."
176
- )
177
- captured_inputs[k] = len(ts)
178
- flat.extend(ts)
179
- else:
180
- flat.extend([None for _ in range(kwarg_sizes[k])])
181
- new_flat_inputs.append(flat)
182
- return new_flat_inputs
183
-
184
- def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
185
- flat_inputs = self.build_inputs_completed_with_none_values()
186
- # This is already checked by build_inputs_completed_with_none_values
187
- # but this is not always well captured by tools checking types.
188
- assert self._max_args is not None and self._max_kwargs is not None
189
- if len({len(flat) for flat in flat_inputs}) != 1:
402
+ candidate.align_with(
403
+ self._best_candidate, self._captured_inputs, self.signature_names
404
+ )
405
+
406
+ def infer_dynamic_shapes(
407
+ self,
408
+ set_batch_dimension_for: set[int | str] | bool | None = None,
409
+ return_flat: bool = False,
410
+ ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
411
+ """Infers dynamic shapes based on the collected tensors.
412
+ Most of the time, models do support a batch dimension
413
+ but this batch dimension has the same value for every input sample.
414
+ Instead of running inference on new samples, argument `set_batch_dimension_for`
415
+ can be used to tell the first dimension is a dynamic dimension for a particular
416
+ set of inputs referenced by their name (str) or their position (int).
417
+
418
+ Args:
419
+ set_batch_dimension_for (set[int | str] | None): Set of input identifiers,
420
+ by name (``str``) or position (``int``), for which the first dimension
421
+ should be treated as a dynamic batch dimension. If ``None`` or empty,
422
+ no additional batch dimensions are marked as dynamic.
423
+ return_flat: Tells the function to return a flat tuple instead of
424
+ a nested structure.
425
+ """
426
+ self.align_inputs_none_values()
427
+ # type checking
428
+ assert self._best_candidate is not None
429
+ assert self._best_candidate.flat_list is not None
430
+ assert self._best_candidate.aligned_flat_list is not None
431
+
432
+ def _set_batch_dimension(name_or_position):
433
+ if not set_batch_dimension_for:
434
+ return False
435
+ if (
436
+ isinstance(set_batch_dimension_for, bool) and set_batch_dimension_for
437
+ ) or name_or_position in set_batch_dimension_for:
438
+ return True
439
+ if isinstance(name_or_position, int):
440
+ torch._check(
441
+ name_or_position < len(self.signature_names),
442
+ lambda: f"argument at position {name_or_position} is out of boundary",
443
+ )
444
+ if self.signature_names[name_or_position] in set_batch_dimension_for:
445
+ return True
446
+ return False
447
+
448
+ def _set_batch_dimension_for_flat_index(index):
449
+ # type checking
450
+ assert self._best_candidate is not None
451
+ return _set_batch_dimension(self._best_candidate.position_to_args_kwargs[index])
452
+
453
+ if len(self._best_candidate.flat_list) != len(self._best_candidate.aligned_flat_list):
190
454
  raise NotImplementedError(
191
455
  "infer_dynamic_shapes is not implemented "
192
- "when the number of input tensors are not the same."
456
+ "when the best candidate is not 'aligned'."
457
+ "This happens when there is not stored set inputs where "
458
+ "all optional inputs showing in other sets are defined."
459
+ )
460
+
461
+ if len({inputs.n_aligned_tensors for inputs in self.inputs}) != 1:
462
+ raise NotImplementedError(
463
+ f"infer_dynamic_shapes is not implemented "
464
+ f"when the number of input tensors are not the same in "
465
+ f"every set of inputs "
466
+ f"{[inputs.n_aligned_tensors for inputs in self.inputs]}."
193
467
  )
194
468
  shape_lists = [
195
- [(None if t is None else t.shape) for t in tensors] for tensors in flat_inputs
469
+ [(None if t is None else t.shape) for t in candidate.aligned_flat_list]
470
+ for candidate in self.inputs
471
+ if candidate.aligned_flat_list is not None
196
472
  ]
197
473
  n_tensors = len(shape_lists[0])
198
474
  dynamic_shapes = [
199
- infer_dynamic_dimensions(
200
- [s for s in [shapes[index] for shapes in shape_lists] if s is not None]
475
+ _infer_dynamic_dimensions(
476
+ [s for s in [shapes[index] for shapes in shape_lists] if s is not None],
477
+ set_batch_dimension=_set_batch_dimension_for_flat_index(index),
201
478
  )
202
479
  for index in range(n_tensors)
203
480
  ]
204
481
  cst = torch.export.Dim.DYNAMIC
205
482
  flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
206
- if len(flat_dynamic_shapes) == len(self._max_args) + len(self._max_kwargs):
483
+ if return_flat:
484
+ return tuple(flat_dynamic_shapes)
485
+ if len(flat_dynamic_shapes) == len(self._best_candidate.args) + len(
486
+ self._best_candidate.kwargs
487
+ ):
207
488
  # It means forward method is called with tensors only.
208
- if not self._max_kwargs:
489
+ if not self._best_candidate.kwargs and not self._best_candidate.cst_kwargs:
209
490
  # only positional arguments
210
491
  return tuple(flat_dynamic_shapes)
211
- if not self._max_args:
492
+ if not self._best_candidate.args:
212
493
  # only named arguments
213
- return dict(zip(list(self._max_kwargs), flat_dynamic_shapes))
494
+ ds = dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes))
495
+ return {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
214
496
  # positional arguments needs to be moved to the named arguments
215
- n_args = len(self._max_args)
216
- pos_names = list(self.signature.parameters)[:n_args]
497
+ n_args = len(self._best_candidate.args)
498
+ pos_names = self.signature_names[:n_args]
217
499
  return {
218
500
  **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
219
- **dict(zip(list(self._max_kwargs), flat_dynamic_shapes[n_args:])),
501
+ **dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])),
502
+ **dict.fromkeys(self._best_candidate.cst_kwargs, None),
220
503
  }
221
504
 
222
505
  # nested types, here comes the fun part because the shapes cannot be unflattened,
@@ -225,7 +508,7 @@ class InputObserverInfo:
225
508
  # with the same number of tensors. The function does not check
226
509
  # if that assumption is true.
227
510
  flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
228
- (self._max_args, self._max_kwargs)
511
+ (self._best_candidate.args, self._best_candidate.kwargs)
229
512
  )
230
513
  torch._check(
231
514
  len(flat_inputs) == len(flat_dynamic_shapes),
@@ -234,96 +517,454 @@ class InputObserverInfo:
234
517
  f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
235
518
  ),
236
519
  )
237
- mapping = {id(t): shape for t, shape in zip(flat_inputs, flat_dynamic_shapes)}
238
- ds_args, ds_kwargs = flatten_unflatten_for_dynamic_shapes(
239
- (self._max_args, self._max_kwargs), change_function=lambda t: mapping[id(t)]
520
+
521
+ index = 0
522
+
523
+ def change_function(t):
524
+ nonlocal index
525
+ if index >= len(flat_dynamic_shapes):
526
+ raise RuntimeError(
527
+ f"Flattened {index} tensors when there are only "
528
+ f"{len(flat_dynamic_shapes)}."
529
+ )
530
+ res = flat_dynamic_shapes[index]
531
+ index += 1
532
+ return res
533
+
534
+ ds_args, ds_kwargs = _flatten_unflatten_for_dynamic_shapes(
535
+ (self._best_candidate.args, self._best_candidate.kwargs),
536
+ change_function=change_function,
240
537
  )
538
+ if self._best_candidate.cst_kwargs:
539
+ ds_kwargs = {**ds_kwargs, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
241
540
  if not ds_kwargs:
242
541
  return tuple(ds_args)
243
542
  if not ds_args:
244
- return tuple(ds_kwargs)
245
- pos_names = list(self.signature.parameters)[: len(ds_args)]
543
+ return ds_kwargs
544
+ pos_names = self.signature_names[: len(ds_args)]
246
545
  return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
247
546
 
248
547
  def infer_arguments(
249
- self, index: int | None = None
250
- ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
251
- # This is already checked by build_inputs_completed_with_none_values
548
+ self, index_or_candidate: InputCandidate | int | None = None, flat: bool = False
549
+ ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
550
+ """Infers arguments based on the collected tensors."""
551
+ # This is already checked by _build_inputs_completed_with_none_values
252
552
  # but this is not always well captured by tools checking types.
253
- assert self._max_args is not None and self._max_kwargs is not None
553
+ self.align_inputs_none_values()
554
+ torch._check(self._best_candidate is not None, lambda: "No input was captured.")
555
+ # type checking
556
+ assert self._best_candidate is not None
254
557
  candidate = None
255
- if index is None:
256
- for i, (args_kwargs, spec) in enumerate(zip(self.flat_inputs, self.inputs_specs)):
257
- args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
258
- if len(args) == len(self._max_args) and len(kwargs) == len(self._max_kwargs):
259
- index = i
260
- candidate = args, kwargs
558
+ if index_or_candidate is None:
559
+ for cand in self.inputs:
560
+ args, kwargs = cand.args, cand.kwargs
561
+ if len(args) == len(self._best_candidate.args) and len(kwargs) == len(
562
+ self._best_candidate.kwargs
563
+ ):
564
+ candidate = cand
261
565
  break
262
- if index is not None:
263
- # found one available set.
264
- args, kwargs = candidate or torch.utils._pytree.tree_unflatten(
265
- self.flat_inputs[index], self.inputs_specs[index]
566
+ elif isinstance(index_or_candidate, int):
567
+ torch._check(
568
+ index_or_candidate < len(self.inputs),
569
+ lambda: (
570
+ f"No stored input set for index="
571
+ f"{index_or_candidate}<{len(self.inputs)}."
572
+ ),
266
573
  )
267
- if not kwargs:
268
- return args
269
- if not args:
270
- return kwargs
271
- # We need to move args to kwargs
272
- pos_names = list(self.signature.parameters)[: len(args)]
273
- return {**dict(zip(pos_names, args)), **kwargs}
274
-
275
- raise NotImplementedError(
276
- "We could not find a good set of inputs/outputs. "
277
- "We need to replace none by empty tensors."
574
+ candidate = self.inputs[index_or_candidate]
575
+ else:
576
+ candidate = index_or_candidate
577
+
578
+ torch._check(candidate is not None, "No input was captured.")
579
+ # type checking
580
+ assert candidate is not None
581
+ if candidate.aligned_flat_list is None:
582
+ raise RuntimeError(
583
+ f"Candidate {candidate} has no aligned flat list of tensors, "
584
+ f"index_or_candidate={index_or_candidate}. You should call "
585
+ f"method 'align_with'."
586
+ )
587
+
588
+ aligned_flat_list = candidate.aligned_flat_list
589
+ if any(t is None for t in aligned_flat_list):
590
+ dynamic_shapes = self.infer_dynamic_shapes(return_flat=True)
591
+ # type checking
592
+ assert isinstance(dynamic_shapes, tuple)
593
+ aligned_flat_list = list(aligned_flat_list)
594
+ for index in range(len(aligned_flat_list)):
595
+ if aligned_flat_list[index] is not None:
596
+ continue
597
+ shape = dynamic_shapes[index]
598
+ all_non_empty_tensors = [
599
+ c.aligned_flat_list[index]
600
+ for c in self.inputs
601
+ if c.aligned_flat_list is not None
602
+ ]
603
+ all_non_empty_tensors_not_none = [
604
+ t for t in all_non_empty_tensors if t is not None
605
+ ]
606
+ if not all_non_empty_tensors_not_none:
607
+ raise RuntimeError(
608
+ f"There is no tensor at position {index} in any flattened inputs."
609
+ )
610
+ tensor = all_non_empty_tensors_not_none.pop()
611
+ if tensor.numel() == 0:
612
+ aligned_flat_list[index] = tensor
613
+ continue
614
+ if not shape:
615
+ aligned_flat_list[index] = torch.zeros(
616
+ tensor.shape, dtype=tensor.dtype, device=tensor.device
617
+ )
618
+ continue
619
+ dim = max(shape)
620
+ torch._check(
621
+ dim < tensor.ndim,
622
+ lambda index=index, shape=shape, tshape=tensor.shape: (
623
+ f"Tensor shape {tshape} does not match the "
624
+ f"dynamic shape {shape} at position {index}."
625
+ ),
626
+ )
627
+ new_shape = list(tensor.shape)
628
+ new_shape[dim] = 0
629
+ aligned_flat_list[index] = torch.empty(
630
+ tuple(new_shape), dtype=tensor.dtype, device=tensor.device
631
+ )
632
+ if flat:
633
+ # type checking
634
+ assert all(t is not None for t in aligned_flat_list)
635
+ # pyrefly: ignore[bad-return]
636
+ return aligned_flat_list
637
+ # type checking
638
+ assert candidate is not None
639
+ assert candidate.aligned_spec is not None
640
+ args, kwargs = torch.utils._pytree.tree_unflatten(
641
+ aligned_flat_list, candidate.aligned_spec
278
642
  )
643
+ if self._best_candidate.cst_kwargs:
644
+ kwargs = {**kwargs, **self._best_candidate.cst_kwargs}
645
+
646
+ if not kwargs:
647
+ return args
648
+ if not args:
649
+ return kwargs
650
+ # We need to move args to kwargs
651
+ pos_names = self.signature_names[: len(args)]
652
+ return {**dict(zip(pos_names, args)), **kwargs}
279
653
 
280
654
 
281
655
  class InputObserver:
282
- def __init__(self, store_n_calls: int = 3):
283
- self.store_n_calls = store_n_calls
284
- self.info: InputObserverInfo | None = None
656
+ """Steals forward method to collect inputs and outputs.
657
+ This information is used to infer dynamic shapes and
658
+ export arguments.
659
+
660
+ Args:
661
+ missing: If a named argument (in kwargs) is missing,
662
+ a default value will be taken in this dictionary,
663
+ this is used when after the prefill step, an argument
664
+ disappears (such as `pixel_values`) and another one
665
+ is added (such as `past_key_values`).
666
+ The values are only to infer dynamic shapes and arguments,
667
+ not to run the model.
668
+
669
+ Examples
670
+ --------
671
+ >>> input_observer = InputObserver()
672
+ >>> with input_observer(model):
673
+ >>> model(x1, y1)
674
+ >>> model(x2, y2)
675
+ >>> ep = torch.export.export( # or torch.onnx.export
676
+ >>> model,
677
+ >>> input_observer.infer_arguments(),
678
+ >>> dynamic_shapes=input_observer.infer_dynamic_shapes(),
679
+ >>> )
680
+
681
+ With LLM:
285
682
 
286
- def _forward_captured(self, *args, _captured_forward=None, **kwargs):
287
- assert _captured_forward is not None, "_captured_forward cannot be None"
683
+ >>> input_observer = InputObserver()
684
+ >>> with input_observer(model):
685
+ >>> model.generate(input_ids)
686
+ >>> ep = torch.export.export( # or torch.onnx.export
687
+ >>> model,
688
+ >>> (),
689
+ >>> kwargs=input_observer.infer_arguments(),
690
+ >>> dynamic_shapes=input_observer.infer_dynamic_shapes(),
691
+ >>> )
692
+
693
+ Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
694
+ :ref:`l-plot-whisper-tiny-export-input-observer`,
695
+ :ref:`l-plot-gemma3-tiny-export-input-observer`.
696
+ """
697
+
698
+ def __init__(self, missing: dict[str, Any] | None = None):
699
+ self.info: InputObserverInfo | None = None # type: ignore[annotation-unchecked]
700
+ self.missing = missing or {}
701
+
702
+ def _replaced_method(
703
+ self,
704
+ *args,
705
+ _captured_method: Callable | None = None,
706
+ _store_n_calls: int = 3,
707
+ **kwargs,
708
+ ):
709
+ assert _captured_method is not None, "_captured_forward cannot be None"
288
710
  assert self.info is not None, "info cannot be None"
289
711
  n_stored = len(self.info)
290
- if n_stored < self.store_n_calls:
712
+ if n_stored < _store_n_calls:
291
713
  self.info.add_inputs(args, kwargs)
292
- res = _captured_forward(*args, **kwargs)
293
- if n_stored < self.store_n_calls:
294
- self.info.add_outputs(res)
714
+ begin = time.perf_counter()
715
+ res = _captured_method(*args, **kwargs)
716
+ duration = time.perf_counter() - begin
717
+ if n_stored < _store_n_calls:
718
+ self.info.add_outputs(res, latency=duration)
295
719
  return res
296
720
 
721
+ def num_obs(self) -> int:
722
+ """Returns the number of stored set if inputs."""
723
+ return 0 if not self.info else len(self.info)
724
+
297
725
  @contextlib.contextmanager
298
- def __call__(self, model: torch.nn.Module):
299
- if self.info is not None:
300
- raise RuntimeError(
301
- "This class was already used to capture a model. Please create a new one."
302
- )
303
- self.info = InputObserverInfo(signature=inspect.signature(model.forward))
304
- forward_method = model.forward
305
- model.forward = (
306
- lambda *args, _captured_forward=forward_method, **kwargs: self._forward_captured(
307
- *args, _captured_forward=_captured_forward, **kwargs
726
+ def __call__(
727
+ self,
728
+ model: torch.nn.Module,
729
+ store_n_calls: int = 3,
730
+ method_name: str = "forward",
731
+ ):
732
+ """Starts collecting inputs and outputs of a specific method.
733
+ The model method is replaced by a new one collecting tensors
734
+ before and after the inner one is called.
735
+ The original method is restored after the collection.
736
+
737
+ Args:
738
+ model: Model
739
+ store_n_calls: The collection stops after this many calls
740
+ to avoid taking too much memory.
741
+ method_name: Method name to spy on.
742
+ """
743
+ if not hasattr(model, method_name):
744
+ raise ValueError(f"Model type {model} does not have a method {method_name!r}.")
745
+ captured_method = getattr(model, method_name)
746
+ sig = inspect.signature(captured_method)
747
+ if self.info is None:
748
+ self.info = InputObserverInfo(
749
+ signature_names=list(sig.parameters),
750
+ default_values={
751
+ p.name: p.default
752
+ for p in sig.parameters.values()
753
+ if p.default != inspect.Parameter.empty
754
+ and isinstance(p.default, (int, bool, str, float))
755
+ },
756
+ missing=self.missing,
308
757
  )
758
+ n_already_stored = len(self.info)
759
+ lambda_method = lambda *args, _cm=captured_method, _snc=( # noqa: E731
760
+ store_n_calls + n_already_stored
761
+ ), **kwargs: self._replaced_method(
762
+ *args, _captured_method=_cm, _store_n_calls=_snc, **kwargs
309
763
  )
764
+
765
+ # It may happen than the signature of the forward is used to trigger a preprocessing.
766
+ # This is used in GenerationMixin (transformers):
767
+ # position_ids_key = "decoder_position_ids" if ... else "position_ids"
768
+ # if position_ids_key in set(inspect.signature(self.forward).parameters.keys()):
769
+ lambda_method.__signature__ = sig # type: ignore[attr-defined]
770
+
771
+ setattr(model, method_name, lambda_method)
772
+
310
773
  try:
311
774
  yield self
312
775
  finally:
313
- model.forward = forward_method
776
+ setattr(model, method_name, captured_method)
314
777
 
315
778
  def _check_captured(self):
316
779
  if self.info is None:
317
780
  raise RuntimeError("No inputs were captured.")
318
781
 
319
- def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
782
+ def infer_dynamic_shapes(
783
+ self, set_batch_dimension_for: set[int | str] | bool | None = None
784
+ ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
785
+ """
786
+ Infers dynamic shapes. Most of the time, models do support a batch dimension
787
+ but this batch dimension has the same value for every input sample.
788
+ Instead of running inference on new samples, argument `set_batch_dimension_for`
789
+ can be used to tell the first dimension is a dynamic dimension for a particular
790
+ set of inputs referenced by their name (str) or their position (int).
791
+
792
+ Args:
793
+ set_batch_dimension_for (set[int | str] | None): A set of input
794
+ identifiers (by position as ``int`` or by name as ``str``) for
795
+ which the first dimension should be treated as a dynamic batch
796
+ dimension. If ``None``, no dimensions are explicitly marked as
797
+ dynamic.
798
+ """
320
799
  self._check_captured()
321
800
  assert self.info is not None # missed by type checking
322
- return self.info.infer_dynamic_shapes()
801
+ return self.info.infer_dynamic_shapes(set_batch_dimension_for=set_batch_dimension_for)
323
802
 
324
803
  def infer_arguments(
325
- self, index: int | None = None
326
- ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
804
+ self,
805
+ index_or_args_or_kwargs: tuple[Any] | dict[str, Any] | int | None = None,
806
+ flat: bool = False,
807
+ ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
808
+ """Infers arguments based on the collected tensors.
809
+
810
+ Args:
811
+ index_or_args_or_kwargs: If missing, the method selects one set of inputs
812
+ among the available ones, usually this inputs containing
813
+ the set of stored inputs with the highest number of tensors.
814
+ The then replaces None values and missing tensors by empty tensors.
815
+ If not missing, it can be an integer to fetch one of the stored set
816
+ or some inputs.
817
+ flat: If True, it returns a flattened list of tensors,
818
+ if False, it returns a tuple or a dictionary preserving
819
+ the nested structures.
820
+
821
+ Returns:
822
+ Inferred arguments, every optional tensor is replaced by a empty tensor.
823
+ """
327
824
  self._check_captured()
328
825
  assert self.info is not None # missed by type checking
329
- return self.info.infer_arguments(index=index)
826
+ index_or_candidate: int | InputCandidate | None = None
827
+ if index_or_args_or_kwargs is None or isinstance(index_or_args_or_kwargs, int):
828
+ index_or_candidate = index_or_args_or_kwargs
829
+ else:
830
+ if isinstance(index_or_args_or_kwargs, tuple):
831
+ index_or_candidate = InputCandidate(
832
+ args=index_or_args_or_kwargs, kwargs={}, clone=False, cst_kwargs={}
833
+ )
834
+ elif isinstance(index_or_args_or_kwargs, dict):
835
+ index_or_candidate = InputCandidate(
836
+ args=(),
837
+ kwargs={
838
+ k: v
839
+ for k, v in index_or_args_or_kwargs.items()
840
+ if k not in self.info.default_values
841
+ },
842
+ clone=False,
843
+ cst_kwargs={
844
+ k: v
845
+ for k, v in index_or_args_or_kwargs.items()
846
+ if k in self.info.default_values
847
+ },
848
+ )
849
+ else:
850
+ raise ValueError(
851
+ f"Unexpected type {type(index_or_args_or_kwargs)} "
852
+ f"for index_or_args_or_kwargs."
853
+ )
854
+ self.info.align_inputs_none_values()
855
+ # type checking
856
+ assert self.info._best_candidate is not None
857
+ assert self.info._captured_inputs is not None
858
+ index_or_candidate.align_with(
859
+ self.info._best_candidate,
860
+ self.info._captured_inputs,
861
+ self.info.signature_names,
862
+ )
863
+ return self.info.infer_arguments(index_or_candidate=index_or_candidate, flat=flat)
864
+
865
+ def check_discrepancies(
866
+ self,
867
+ onnx_model: str | onnx.ModelProto,
868
+ atol: float = 1e-4,
869
+ rtol: float = 0.1,
870
+ hist=(0.1, 0.01),
871
+ progress_bar: bool = False,
872
+ include_io: bool = True,
873
+ ) -> list[dict[str, str | int | float | bool]]:
874
+ """Computes the discrepancies between the saved inputs and outputs
875
+ with the saved onnx model.
876
+
877
+ Args:
878
+ onnx_model:
879
+ ONNX Model to verify.
880
+ atol:
881
+ Absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16.
882
+ rtol:
883
+ Relative tolerance.
884
+ hist:
885
+ Thresholds, the function determines the number of discrepancies
886
+ above these thresholds.
887
+ progress_bar:
888
+ Shows a progress bar (requires :epkg:`tqdm`).
889
+ include_io:
890
+ Shows inputs/outputs shapes in the summary
891
+ returned by this function.
892
+
893
+ Returns:
894
+ A list of dictionaries, ready to be consumed by a dataframe.
895
+
896
+ The function catches exceptions, it shows the error in the returned
897
+ summary.
898
+ """
899
+ sess = OnnxruntimeEvaluator(onnx_model, whole=True)
900
+ input_names = sess.input_names
901
+ self._check_captured()
902
+ # type checking
903
+ assert self.info is not None
904
+ assert self.info.inputs is not None
905
+ assert self.info.flat_outputs is not None
906
+ assert self.info.latencies is not None
907
+ io_sets = list(zip(self.info.inputs, self.info.flat_outputs, self.info.latencies))
908
+ if progress_bar:
909
+ from tqdm import tqdm
910
+
911
+ loop = tqdm(io_sets)
912
+ else:
913
+ loop = io_sets
914
+ lhist = list(hist)
915
+ data: list[dict[str, Any]] = []
916
+ for inputs, outputs, latency in loop:
917
+ # type checking
918
+ assert inputs.aligned_flat_list is not None
919
+ if len(input_names) != len(inputs.aligned_flat_list):
920
+ raise RuntimeError(
921
+ f"There are ({len(inputs.aligned_flat_list)}) "
922
+ f"tensors but the model expects {len(input_names)}."
923
+ )
924
+ n_none = sum([t is None for t in inputs.aligned_flat_list])
925
+ n_empty = sum([t is None or t.numel() == 0 for t in inputs.aligned_flat_list])
926
+
927
+ feeds = dict(zip(input_names, self.info.infer_arguments(inputs, flat=True)))
928
+
929
+ begin = time.perf_counter()
930
+ try:
931
+ ort_outputs = sess.run(None, feeds)
932
+ error = None
933
+ except Exception as e:
934
+ error = str(e)
935
+ ort_outputs = None
936
+
937
+ duration = time.perf_counter() - begin
938
+ if error:
939
+ diff: dict[str, str | int | float | bool] = dict(error=error, SUCCESS=False)
940
+ else:
941
+ # The last output may be empty and torch could skip it.
942
+ if isinstance(outputs, list) and isinstance(ort_outputs, list):
943
+ while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0:
944
+ ort_outputs.pop()
945
+ diff = max_diff(outputs, ort_outputs, hist=lhist) # type: ignore[assignment]
946
+ if "rep" in diff and isinstance(diff["rep"], dict):
947
+ diff.update(diff["rep"])
948
+ del diff["rep"]
949
+ diff["SUCCESS"] = (
950
+ isinstance(diff["abs"], float)
951
+ and isinstance(diff["rel"], float)
952
+ and diff["abs"] < atol
953
+ and diff["rel"] < rtol
954
+ )
955
+ diff.update(
956
+ dict(
957
+ index=len(diff),
958
+ duration_torch=latency,
959
+ ort_duration=duration,
960
+ n_inputs=len(input_names),
961
+ n_none=n_none,
962
+ n_empty=n_empty,
963
+ )
964
+ )
965
+ if include_io:
966
+ diff["inputs"] = string_type(feeds, with_shape=True)
967
+ diff["outputs_torch"] = string_type(outputs, with_shape=True)
968
+ diff["outputs_ort"] = string_type(ort_outputs, with_shape=True)
969
+ data.append(diff)
970
+ return data