onnx-diagnostic 0.8.11__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
- onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
- onnx_diagnostic/ci_models/data/__init__.py +0 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +8 -3
- onnx_diagnostic/export/api.py +11 -0
- onnx_diagnostic/export/dynamic_shapes.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +96 -30
- onnx_diagnostic/helpers/helper.py +39 -0
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +5 -1
- onnx_diagnostic/helpers/rt_helper.py +53 -9
- onnx_diagnostic/helpers/torch_helper.py +7 -2
- onnx_diagnostic/investigate/input_observer.py +793 -152
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +28 -2
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +24 -21
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
```diff
--- a/onnx_diagnostic/investigate/input_observer.py
+++ b/onnx_diagnostic/investigate/input_observer.py
@@ -1,16 +1,21 @@
 import contextlib
 import inspect
+import time
 from typing import Any, Callable, Sequence
+import onnx
 import torch
+from ..helpers import max_diff, string_type
+from ..reference import OnnxruntimeEvaluator
 
+EOL = "\n"
 
-
+
+def _flatten_unflatten_for_dynamic_shapes(
     obj: Any,
     use_dict: bool = True,
     change_function: Callable[[torch.Tensor], Any] | None = None,
 ) -> Any:
-    """
-    Returns the object in a different structure similar to what
+    """Returns the object in a different structure similar to what
     the definition of the dynamic shapes should use.
 
     Args:
@@ -27,7 +32,7 @@ def flatten_unflatten_for_dynamic_shapes(
             like replace them by a shape
 
     Returns:
-        the
+        the flattened object
     """
     if isinstance(obj, torch.Tensor):
        return change_function(obj) if change_function else obj
@@ -38,7 +43,7 @@ def flatten_unflatten_for_dynamic_shapes(
     for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
-        value =
+        value = _flatten_unflatten_for_dynamic_shapes(
             value, use_dict=use_dict, change_function=change_function
         )
         subtrees.append(value)
@@ -66,157 +71,435 @@ def flatten_unflatten_for_dynamic_shapes(
         return subtrees
 
 
-def
-
-
+def _infer_dynamic_dimensions(
+    shape_list: Sequence[tuple[int, ...]], set_batch_dimension: bool = False
+) -> list[int]:
+    """Returns the list of dynamic dimensions given a list of shapes
     corresponding to the same tensor.
 
     Args:
         shape_list:
             list of shapes, they must all have the same length
+        set_batch_dimension:
+            forces the first dimension to be treated as dynamic,
+            even if all shapes have the same value for that dimension
 
     Returns:
         list of dynamic dimensions
     """
     unique_ranks = {len(shape) for shape in shape_list}
     torch._check(
-        len(unique_ranks) == 1,
+        len(unique_ranks) == 1,
+        lambda: "all shapes in shape_list must have the same rank",
     )
     rank = unique_ranks.pop()
     dynamic = []
     for i in range(rank):
         dims = [shape[i] for shape in shape_list]
-        if len(set(dims)) > 1:
+        if len(set(dims)) > 1 or (i == 0 and set_batch_dimension):
             dynamic.append(i)
     return dynamic
 
 
+class InputCandidate:
+    """Retains one set of inputs given to the forward method or any
+    other method the class :class:`InputObserver` is stealing from.
+    Any class is allowed as long as it can be flattened.
+
+    Args:
+        args:
+            Positional arguments.
+        kwargs:
+            Optional arguments.
+        clone:
+            Clones the inputs before storing them. Some tensors
+            may be modified inplace, the original value must be retained.
+        cst_kwargs:
+            Any optional arguments constant over multiple calls.
+            int, float, str, bool values must be stored here.
+
+    The constructor flattens the received arguments.
+    Any necessary flattening function should have been registered first.
+    """
+
+    def __init__(
+        self,
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+        clone: bool,
+        cst_kwargs: dict[str, int | str | float | bool],
+    ):
+        self.args = args
+        self.kwargs = kwargs
+        self.flat_list, self.spec = torch.utils._pytree.tree_flatten((args, kwargs))
+        self.n_tensors = sum(t is not None for t in self.flat_list)
+        self._position_to_args_kwargs: list[int | str] | None = None
+        self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None
+        self.cst_kwargs = cst_kwargs.copy()
+        assert "use_cache" not in self.cst_kwargs
+
+        if clone:
+            self.flat_list = [
+                (None if not isinstance(t, torch.Tensor) else t.clone().detach())
+                for t in self.flat_list
+            ]
+            self.args, self.kwargs = torch.utils._pytree.tree_unflatten(
+                self.flat_list, self.spec
+            )
+
+        self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None
+        self.aligned_flat_list: list[torch.Tensor | None] | None = None
+
+    def __str__(self) -> str:
+        return (
+            f"{self.__class__.__name__}({len(self.args)} args, "
+            f"{len(self.kwargs)} kwargs, {len(self.flat_list)} tensors, "
+            f"{len(self.aligned_flat_list or [])} aligned tensors)"
+        )
+
+    def __len__(self) -> int:
+        """Returns the number of flattened tensors, None tensors are included."""
+        return len(self.flat_list)
+
+    def str_obs(self) -> str:
+        """Prints out some information about the observations."""
+        return (
+            f"InputCandidate(args={string_type(self.args, with_shape=True)}, "
+            f"kwargs={string_type(self.kwargs, with_shape=True)}, "
+            f"cst_kwargs={self.cst_kwargs})"
+        )
+
+    def build_mappings(self) -> list[int | str]:
+        if self._position_to_args_kwargs is not None:
+            return self._position_to_args_kwargs
+        self._n_tensors_for_args_kwargs = {}
+
+        flat_index_to_args: list[int | str] = []
+        for index_args, a in enumerate(self.args):
+            size = len(torch.utils._pytree.tree_flatten(a)[0])
+            self._n_tensors_for_args_kwargs[index_args] = size
+            flat_index_to_args.extend([index_args] * size)
+        for k, v in self.kwargs.items():
+            size = len(torch.utils._pytree.tree_flatten(v)[0])
+            self._n_tensors_for_args_kwargs[k] = size
+            flat_index_to_args.extend([k] * size)
+
+        self._position_to_args_kwargs = flat_index_to_args
+        return self._position_to_args_kwargs
+
+    @property
+    def position_to_args_kwargs(self) -> list[int | str]:
+        """Returns the corresponding args or kwargs
+        for every tensor in the flattened inputs.
+        """
+        if self._position_to_args_kwargs is None:
+            self.build_mappings()
+        # type checking is missing it
+        assert self._position_to_args_kwargs is not None
+        return self._position_to_args_kwargs
+
+    @property
+    def n_tensors_for_args_kwargs(self) -> dict[int | str, int]:
+        """Returns the number of flat tensors in every args or kwargs."""
+        if self._n_tensors_for_args_kwargs is None:
+            self.build_mappings()
+        # type checking is missing it
+        assert self._n_tensors_for_args_kwargs is not None
+        return self._n_tensors_for_args_kwargs
+
+    def _set_aligned_flat_list(
+        self,
+        aligned_flat_list: list[torch.Tensor | None],
+        aligned_spec: torch.utils._pytree.PyTreeSpec,
+    ):
+        self.aligned_flat_list = aligned_flat_list
+        self.aligned_spec = aligned_spec
+
+    def align_with(
+        self,
+        best_candidate: "InputCandidate",
+        captured_inputs: dict[int | str, int],
+        signature_names: list[str],
+    ):
+        """Two candidates are considered aligned if, after being flattened,
+        they have the same number of tensors (None allowed)."""
+        if self.cst_kwargs != best_candidate.cst_kwargs:
+            raise RuntimeError(
+                f"Two calls were made with different constant values, "
+                f"{self.cst_kwargs} != {best_candidate.cst_kwargs}"
+            )
+
+        args = self.args
+        if len(self.args) > len(best_candidate.args):
+            # We need to move some args to kwargs as the best_candidate does.
+            new_kwargs = {}
+            for i in range(len(best_candidate.args), len(self.args)):
+                new_kwargs[signature_names[i]] = args[i]
+            args = args[: len(best_candidate.args)]
+            kwargs = {**new_kwargs, **self.kwargs}
+        else:
+            kwargs = self.kwargs
+
+        flat = []
+        for i in range(len(best_candidate.args)):
+            if i < len(args) and (isinstance(args[i], torch.Tensor) or args[i]):
+                ts = torch.utils._pytree.tree_flatten(self.args[i])[0]
+                if i in captured_inputs and captured_inputs[i] != len(ts):
+                    raise RuntimeError(
+                        f"Positional argument {i} has {len(ts)} tensors "
+                        f"but previously got {captured_inputs[i]} tensors. "
+                        f"Inference is impossible in that case."
+                    )
+                captured_inputs[i] = len(ts)
+                flat.extend(ts)
+                continue
+            # If the argument i is not specified or is None or an empty container.
+            flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[i])])
+
+        for k in best_candidate.kwargs:
+            if k in kwargs and (isinstance(kwargs[k], torch.Tensor) or kwargs[k]):
+                ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
+                if k in captured_inputs and captured_inputs[k] != len(ts):
+                    raise RuntimeError(
+                        f"Named argument {k!r} has {len(ts)} tensors "
+                        f"but previously got {captured_inputs[k]} tensors in "
+                        f"kwargs={list(kwargs)}. "
+                        f"Inference is impossible in that case."
+                    )
+                captured_inputs[k] = len(ts)
+                flat.extend(ts)
+                continue
+            # If the argument k is not specified or is None or an empty container.
+            flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[k])])
+
+        self._set_aligned_flat_list(flat, best_candidate.spec)
+
+    @property
+    def n_aligned_tensors(self) -> int:
+        if self.aligned_flat_list is None:
+            raise RuntimeError("This input was not aligned with the others.")
+        return len(self.aligned_flat_list)
+
+
 class InputObserverInfo:
-
-
-
-
-
-
-
-
+    """Contains all the necessary information to infer dynamic shapes
+    and the arguments to send to :func:`torch.export.export`.
+
+    Args:
+        signature_names: Names of the arguments of the method
+            the collected tensors come from. They are used if it becomes
+            necessary to move positional arguments to named ones.
+            They are used a second time because :func:`torch.export.export`
+            cares about the order in kwargs and dynamic shapes, it needs
+            to be the same in the ordered dictionaries `add_inputs` receives.
+        default_values: Default values defined by the signature of the function,
+            any value equal to that is ignored to simplify the export.
+        missing: If a named argument (in kwargs) is missing,
+            a default value will be taken from this dictionary,
+            this is used when after the prefill step, an argument
+            disappears (such as `pixel_values`) and another one
+            is added (such as `past_key_values`).
+            The values are only used to infer dynamic shapes and arguments,
+            not to run the model.
+    """
 
-
-        self
+    def __init__(
+        self,
+        signature_names: list[str],
+        default_values: dict[str, int | bool | str | float],
+        missing: dict[str, Any],
+    ):
+        self.default_values = default_values
+        self.missing = missing
+        self.inputs: list[InputCandidate] = []
+        self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
+        self.flat_outputs: list[list[torch.Tensor | None]] = []
+        self.latencies: list[float] = []
+        self.signature_names = signature_names
+        self._best_candidate: InputCandidate | None = None
+        self._captured_inputs: dict[int | str, int] | None = None
 
     def __len__(self) -> int:
-
+        """Returns the number of collected sets of inputs/outputs."""
+        return len(self.inputs)
 
     def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
+        """Stores one set of inputs. They are deepcopied.
+
+        Args:
+            args: Positional arguments.
+            kwargs: Named arguments.
+        """
+        cst_kwargs = {
+            k: v
+            for k, v in kwargs.items()
+            if k in self.signature_names
+            and isinstance(v, (int, float, bool, str))
+            and v != self.default_values.get(k, None)
+            and self.default_values.get(k, None) is not None
+        }
         kwargs = {
             k: v
             for k, v in kwargs.items()
             if v is not None and not isinstance(v, (int, float, bool))
         }
-        flat_args, spec = torch.utils._pytree.tree_flatten((args, kwargs))
-        self.inputs_specs.append(spec)
-        cloned = [
-            (None if not isinstance(t, torch.Tensor) else t.clone().detach())
-            for t in flat_args
-        ]
-        self.flat_inputs.append(cloned)
 
-
-
-
-
-
+        # adds missing attributes
+        for k, v in self.missing.items():
+            if k not in kwargs:
+                kwargs[k] = v
+
+        # kwargs may come in a different order each time.
+        # dictionaries are ordered and torch.export.export expects
+        # dynamic shapes and kwargs to follow the same order.
+
+        ordered_kwargs = {k: kwargs[k] for k in self.signature_names if k in kwargs}
+        for k, v in kwargs.items():
+            if k not in ordered_kwargs:
+                ordered_kwargs[k] = v
+
+        candidate = InputCandidate(args, ordered_kwargs, clone=True, cst_kwargs=cst_kwargs)
+        self.inputs.append(candidate)
+        if self._best_candidate is None or len(self._best_candidate) < len(candidate):
+            self._best_candidate = candidate
 
-    def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
+    def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...], latency: float):
+        """Stores outputs. They are deepcopied."""
         flat_res, spec = torch.utils._pytree.tree_flatten(res)
         self.outputs_specs.append(spec)
-        self.flat_outputs.append(
+        self.flat_outputs.append(
+            [(None if t is None else t.clone().detach()) for t in flat_res]
+        )
+        self.latencies.append(latency)
 
-    def
-
-
+    def align_inputs_none_values(self):
+        """Once the best candidate is chosen, this method aligns every set of inputs
+        on the best candidate, it inserts None at the right position when
+        optional inputs are not specified. We consider a set of inputs is aligned
+        if this method does not change the original flattened inputs.
+        """
+        if not self.inputs or self._best_candidate is None:
            raise RuntimeError("No inputs were captured.")
-
-
-
-
+
+        if all(candidate.aligned_flat_list is not None for candidate in self.inputs):
+            # No new inputs, no alignment is necessary.
+            return
 
         # Let's reprocess everything.
-
-
-
-
-
+        self._captured_inputs = {}
+        for candidate in self.inputs:
+            if len(set(candidate.kwargs) | set(self._best_candidate.kwargs)) > len(
+                self._best_candidate.kwargs
+            ):
                raise RuntimeError(
-                    "At least one call to the observed model "
-                    "must contain all the named arguments."
+                    f"At least one call to the observed model "
+                    f"must contain all the named arguments. "
+                    f"candidate kwargs={list(candidate.kwargs)}, "
+                    f"best candidate kwargs={list(self._best_candidate.kwargs)}, "
+                    f"all candidate kwargs={EOL}"
+                    f"{EOL.join(string_type(c.kwargs, with_shape=True) for c in self.inputs)}"
                )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            candidate.align_with(
+                self._best_candidate, self._captured_inputs, self.signature_names
+            )
+
+    def infer_dynamic_shapes(
+        self,
+        set_batch_dimension_for: set[int | str] | bool | None = None,
+        return_flat: bool = False,
+    ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
+        """Infers dynamic shapes based on the collected tensors.
+        Most of the time, models do support a batch dimension
+        but this batch dimension has the same value for every input sample.
+        Instead of running inference on new samples, argument `set_batch_dimension_for`
+        can be used to state that the first dimension is dynamic for a particular
+        set of inputs referenced by their name (str) or their position (int).
+
+        Args:
+            set_batch_dimension_for (set[int | str] | None): Set of input identifiers,
+                by name (``str``) or position (``int``), for which the first dimension
+                should be treated as a dynamic batch dimension. If ``None`` or empty,
+                no additional batch dimensions are marked as dynamic.
+            return_flat: Tells the function to return a flat tuple instead of
+                nested structures.
+        """
+        self.align_inputs_none_values()
+        # type checking
+        assert self._best_candidate is not None
+        assert self._best_candidate.flat_list is not None
+        assert self._best_candidate.aligned_flat_list is not None
+
+        def _set_batch_dimension(name_or_position):
+            if not set_batch_dimension_for:
+                return False
+            if (
+                isinstance(set_batch_dimension_for, bool) and set_batch_dimension_for
+            ) or name_or_position in set_batch_dimension_for:
+                return True
+            if isinstance(name_or_position, int):
+                torch._check(
+                    name_or_position < len(self.signature_names),
+                    lambda: f"argument at position {name_or_position} is out of boundary",
+                )
+                if self.signature_names[name_or_position] in set_batch_dimension_for:
+                    return True
+            return False
+
+        def _set_batch_dimension_for_flat_index(index):
+            # type checking
+            assert self._best_candidate is not None
+            return _set_batch_dimension(self._best_candidate.position_to_args_kwargs[index])
+
+        if len(self._best_candidate.flat_list) != len(self._best_candidate.aligned_flat_list):
             raise NotImplementedError(
                 "infer_dynamic_shapes is not implemented "
-                "when the
+                "when the best candidate is not 'aligned'. "
+                "This happens when there is no stored set of inputs where "
+                "all optional inputs showing in other sets are defined."
+            )
+
+        if len({inputs.n_aligned_tensors for inputs in self.inputs}) != 1:
+            raise NotImplementedError(
+                f"infer_dynamic_shapes is not implemented "
+                f"when the number of input tensors is not the same in "
+                f"every set of inputs "
+                f"{[inputs.n_aligned_tensors for inputs in self.inputs]}."
             )
         shape_lists = [
-            [(None if t is None else t.shape) for t in
+            [(None if t is None else t.shape) for t in candidate.aligned_flat_list]
+            for candidate in self.inputs
+            if candidate.aligned_flat_list is not None
         ]
         n_tensors = len(shape_lists[0])
         dynamic_shapes = [
-
-            [s for s in [shapes[index] for shapes in shape_lists] if s is not None]
+            _infer_dynamic_dimensions(
+                [s for s in [shapes[index] for shapes in shape_lists] if s is not None],
+                set_batch_dimension=_set_batch_dimension_for_flat_index(index),
             )
             for index in range(n_tensors)
         ]
         cst = torch.export.Dim.DYNAMIC
         flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
-        if
+        if return_flat:
+            return tuple(flat_dynamic_shapes)
+        if len(flat_dynamic_shapes) == len(self._best_candidate.args) + len(
+            self._best_candidate.kwargs
+        ):
            # It means the forward method is called with tensors only.
-            if not self.
+            if not self._best_candidate.kwargs and not self._best_candidate.cst_kwargs:
                # only positional arguments
                return tuple(flat_dynamic_shapes)
-            if not self.
+            if not self._best_candidate.args:
                # only named arguments
-
+                ds = dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes))
+                return {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
            # positional arguments need to be moved to the named arguments
-            n_args = len(self.
-            pos_names =
+            n_args = len(self._best_candidate.args)
+            pos_names = self.signature_names[:n_args]
            return {
                **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
-                **dict(zip(list(self.
+                **dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])),
+                **dict.fromkeys(self._best_candidate.cst_kwargs, None),
            }
 
        # nested types, here comes the fun part because the shapes cannot be unflattened,
@@ -225,7 +508,7 @@ class InputObserverInfo:
         # with the same number of tensors. The function does not check
         # if that assumption is true.
         flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
-            (self.
+            (self._best_candidate.args, self._best_candidate.kwargs)
         )
         torch._check(
             len(flat_inputs) == len(flat_dynamic_shapes),
@@ -234,96 +517,454 @@ class InputObserverInfo:
                 f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
             ),
         )
-
-
-
+
+        index = 0
+
+        def change_function(t):
+            nonlocal index
+            if index >= len(flat_dynamic_shapes):
+                raise RuntimeError(
+                    f"Flattened {index} tensors when there are only "
+                    f"{len(flat_dynamic_shapes)}."
+                )
+            res = flat_dynamic_shapes[index]
+            index += 1
+            return res
+
+        ds_args, ds_kwargs = _flatten_unflatten_for_dynamic_shapes(
+            (self._best_candidate.args, self._best_candidate.kwargs),
+            change_function=change_function,
         )
+        if self._best_candidate.cst_kwargs:
+            ds_kwargs = {**ds_kwargs, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
         if not ds_kwargs:
             return tuple(ds_args)
         if not ds_args:
-            return
-        pos_names =
+            return ds_kwargs
+        pos_names = self.signature_names[: len(ds_args)]
         return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
 
     def infer_arguments(
-        self,
-    ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
-
+        self, index_or_candidate: InputCandidate | int | None = None, flat: bool = False
+    ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        """Infers arguments based on the collected tensors."""
+        # This is already checked by _build_inputs_completed_with_none_values
        # but this is not always well captured by tools checking types.
-
+        self.align_inputs_none_values()
+        torch._check(self._best_candidate is not None, lambda: "No input was captured.")
+        # type checking
+        assert self._best_candidate is not None
         candidate = None
-        if
-            for
-                args, kwargs =
-                if len(args) == len(self.
-
-
+        if index_or_candidate is None:
+            for cand in self.inputs:
+                args, kwargs = cand.args, cand.kwargs
+                if len(args) == len(self._best_candidate.args) and len(kwargs) == len(
+                    self._best_candidate.kwargs
+                ):
+                    candidate = cand
                    break
-
-
-
-
+        elif isinstance(index_or_candidate, int):
+            torch._check(
+                index_or_candidate < len(self.inputs),
+                lambda: (
+                    f"No stored input set for index="
+                    f"{index_or_candidate}<{len(self.inputs)}."
+                ),
            )
-
-
-
-
-
-
-
-
-
-
-
+            candidate = self.inputs[index_or_candidate]
+        else:
+            candidate = index_or_candidate
+
+        torch._check(candidate is not None, "No input was captured.")
+        # type checking
+        assert candidate is not None
+        if candidate.aligned_flat_list is None:
+            raise RuntimeError(
+                f"Candidate {candidate} has no aligned flat list of tensors, "
+                f"index_or_candidate={index_or_candidate}. You should call "
+                f"method 'align_with'."
+            )
+
+        aligned_flat_list = candidate.aligned_flat_list
+        if any(t is None for t in aligned_flat_list):
+            dynamic_shapes = self.infer_dynamic_shapes(return_flat=True)
+            # type checking
+            assert isinstance(dynamic_shapes, tuple)
+            aligned_flat_list = list(aligned_flat_list)
+            for index in range(len(aligned_flat_list)):
+                if aligned_flat_list[index] is not None:
+                    continue
+                shape = dynamic_shapes[index]
+                all_non_empty_tensors = [
+                    c.aligned_flat_list[index]
+                    for c in self.inputs
+                    if c.aligned_flat_list is not None
+                ]
+                all_non_empty_tensors_not_none = [
+                    t for t in all_non_empty_tensors if t is not None
+                ]
+                if not all_non_empty_tensors_not_none:
+                    raise RuntimeError(
+                        f"There is no tensor at position {index} in any flattened inputs."
+                    )
+                tensor = all_non_empty_tensors_not_none.pop()
+                if tensor.numel() == 0:
+                    aligned_flat_list[index] = tensor
+                    continue
+                if not shape:
+                    aligned_flat_list[index] = torch.zeros(
+                        tensor.shape, dtype=tensor.dtype, device=tensor.device
+                    )
+                    continue
+                dim = max(shape)
+                torch._check(
+                    dim < tensor.ndim,
+                    lambda index=index, shape=shape, tshape=tensor.shape: (
+                        f"Tensor shape {tshape} does not match the "
+                        f"dynamic shape {shape} at position {index}."
+                    ),
+                )
+                new_shape = list(tensor.shape)
+                new_shape[dim] = 0
+                aligned_flat_list[index] = torch.empty(
+                    tuple(new_shape), dtype=tensor.dtype, device=tensor.device
+                )
+        if flat:
+            # type checking
+            assert all(t is not None for t in aligned_flat_list)
+            # pyrefly: ignore[bad-return]
+            return aligned_flat_list
+        # type checking
+        assert candidate is not None
+        assert candidate.aligned_spec is not None
+        args, kwargs = torch.utils._pytree.tree_unflatten(
+            aligned_flat_list, candidate.aligned_spec
         )
+        if self._best_candidate.cst_kwargs:
+            kwargs = {**kwargs, **self._best_candidate.cst_kwargs}
+
+        if not kwargs:
+            return args
+        if not args:
+            return kwargs
+        # We need to move args to kwargs
+        pos_names = self.signature_names[: len(args)]
+        return {**dict(zip(pos_names, args)), **kwargs}
 
 
 class InputObserver:
-
-
-
+    """Steals the forward method to collect inputs and outputs.
+    This information is used to infer dynamic shapes and
+    export arguments.
+
+    Args:
+        missing: If a named argument (in kwargs) is missing,
+            a default value will be taken from this dictionary,
+            this is used when after the prefill step, an argument
+            disappears (such as `pixel_values`) and another one
+            is added (such as `past_key_values`).
+            The values are only used to infer dynamic shapes and arguments,
+            not to run the model.
+
+    Examples
+    --------
+    >>> input_observer = InputObserver()
+    >>> with input_observer(model):
+    >>>     model(x1, y1)
+    >>>     model(x2, y2)
+    >>> ep = torch.export.export(  # or torch.onnx.export
+    >>>     model,
+    >>>     input_observer.infer_arguments(),
+    >>>     dynamic_shapes=input_observer.infer_dynamic_shapes(),
+    >>> )
+
+    With an LLM:
 
-
-
+    >>> input_observer = InputObserver()
+    >>> with input_observer(model):
+    >>>     model.generate(input_ids)
+    >>> ep = torch.export.export(  # or torch.onnx.export
+    >>>     model,
+    >>>     (),
+    >>>     kwargs=input_observer.infer_arguments(),
+    >>>     dynamic_shapes=input_observer.infer_dynamic_shapes(),
+    >>> )
+
+    Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
+    :ref:`l-plot-whisper-tiny-export-input-observer`,
+    :ref:`l-plot-gemma3-tiny-export-input-observer`.
+    """
+
+    def __init__(self, missing: dict[str, Any] | None = None):
+        self.info: InputObserverInfo | None = None  # type: ignore[annotation-unchecked]
+        self.missing = missing or {}
+
+    def _replaced_method(
+        self,
+        *args,
+        _captured_method: Callable | None = None,
+        _store_n_calls: int = 3,
+        **kwargs,
+    ):
+        assert _captured_method is not None, "_captured_method cannot be None"
         assert self.info is not None, "info cannot be None"
         n_stored = len(self.info)
-        if n_stored <
+        if n_stored < _store_n_calls:
            self.info.add_inputs(args, kwargs)
-
-
-
+        begin = time.perf_counter()
+        res = _captured_method(*args, **kwargs)
+        duration = time.perf_counter() - begin
+        if n_stored < _store_n_calls:
+            self.info.add_outputs(res, latency=duration)
         return res
 
+    def num_obs(self) -> int:
+        """Returns the number of stored sets of inputs."""
+        return 0 if not self.info else len(self.info)
+
     @contextlib.contextmanager
-    def __call__(
-
-
-
-
-
-
-        model
-
-
+    def __call__(
+        self,
+        model: torch.nn.Module,
+        store_n_calls: int = 3,
+        method_name: str = "forward",
+    ):
+        """Starts collecting inputs and outputs of a specific method.
+        The model method is replaced by a new one collecting tensors
+        before and after the inner one is called.
+        The original method is restored after the collection.
+
+        Args:
+            model: Model
+            store_n_calls: The collection stops after this many calls
+                to avoid taking too much memory.
+            method_name: Method name to spy on.
+        """
+        if not hasattr(model, method_name):
+            raise ValueError(f"Model type {model} does not have a method {method_name!r}.")
+        captured_method = getattr(model, method_name)
+        sig = inspect.signature(captured_method)
+        if self.info is None:
+            self.info = InputObserverInfo(
+                signature_names=list(sig.parameters),
+                default_values={
+                    p.name: p.default
+                    for p in sig.parameters.values()
+                    if p.default != inspect.Parameter.empty
+                    and isinstance(p.default, (int, bool, str, float))
+                },
+                missing=self.missing,
            )
+        n_already_stored = len(self.info)
+        lambda_method = lambda *args, _cm=captured_method, _snc=(  # noqa: E731
+            store_n_calls + n_already_stored
+        ), **kwargs: self._replaced_method(
+            *args, _captured_method=_cm, _store_n_calls=_snc, **kwargs
        )
+
+        # It may happen that the signature of the forward is used to trigger a preprocessing.
+        # This is used in GenerationMixin (transformers):
+        # position_ids_key = "decoder_position_ids" if ... else "position_ids"
+        # if position_ids_key in set(inspect.signature(self.forward).parameters.keys()):
+        lambda_method.__signature__ = sig  # type: ignore[attr-defined]
+
+        setattr(model, method_name, lambda_method)
+
         try:
             yield self
         finally:
-            model
+            setattr(model, method_name, captured_method)
 
     def _check_captured(self):
         if self.info is None:
             raise RuntimeError("No inputs were captured.")
 
-    def infer_dynamic_shapes(
+    def infer_dynamic_shapes(
+        self, set_batch_dimension_for: set[int | str] | bool | None = None
+    ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
+        """
+        Infers dynamic shapes. Most of the time, models do support a batch dimension
+        but this batch dimension has the same value for every input sample.
+        Instead of running inference on new samples, argument `set_batch_dimension_for`
+        can be used to state that the first dimension is dynamic for a particular
+        set of inputs referenced by their name (str) or their position (int).
+
+        Args:
+            set_batch_dimension_for (set[int | str] | None): A set of input
+                identifiers (by position as ``int`` or by name as ``str``) for
+                which the first dimension should be treated as a dynamic batch
+                dimension. If ``None``, no dimensions are explicitly marked as
+                dynamic.
+        """
         self._check_captured()
         assert self.info is not None  # missed by type checking
-        return self.info.infer_dynamic_shapes()
+        return self.info.infer_dynamic_shapes(set_batch_dimension_for=set_batch_dimension_for)
 
     def infer_arguments(
-        self,
-
+        self,
+        index_or_args_or_kwargs: tuple[Any] | dict[str, Any] | int | None = None,
+        flat: bool = False,
+    ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        """Infers arguments based on the collected tensors.
+
+        Args:
+            index_or_args_or_kwargs: If missing, the method selects one set of inputs
+                among the available ones, usually the stored set of inputs
+                with the highest number of tensors.
+                The method then replaces None values and missing tensors with empty tensors.
+                If not missing, it can be an integer to fetch one of the stored sets
+                or some inputs.
+            flat: If True, it returns a flattened list of tensors,
+                if False, it returns a tuple or a dictionary preserving
+                the nested structures.
+
+        Returns:
+            Inferred arguments, every optional tensor is replaced by an empty tensor.
+        """
         self._check_captured()
         assert self.info is not None  # missed by type checking
-
+        index_or_candidate: int | InputCandidate | None = None
+        if index_or_args_or_kwargs is None or isinstance(index_or_args_or_kwargs, int):
+            index_or_candidate = index_or_args_or_kwargs
+        else:
+            if isinstance(index_or_args_or_kwargs, tuple):
+                index_or_candidate = InputCandidate(
+                    args=index_or_args_or_kwargs, kwargs={}, clone=False, cst_kwargs={}
+                )
+            elif isinstance(index_or_args_or_kwargs, dict):
+                index_or_candidate = InputCandidate(
+                    args=(),
+                    kwargs={
+                        k: v
+                        for k, v in index_or_args_or_kwargs.items()
+                        if k not in self.info.default_values
+                    },
+                    clone=False,
+                    cst_kwargs={
+                        k: v
+                        for k, v in index_or_args_or_kwargs.items()
+                        if k in self.info.default_values
+                    },
+                )
+            else:
+                raise ValueError(
+                    f"Unexpected type {type(index_or_args_or_kwargs)} "
+                    f"for index_or_args_or_kwargs."
+                )
+            self.info.align_inputs_none_values()
+            # type checking
+            assert self.info._best_candidate is not None
+            assert self.info._captured_inputs is not None
+            index_or_candidate.align_with(
+                self.info._best_candidate,
+                self.info._captured_inputs,
+                self.info.signature_names,
+            )
+        return self.info.infer_arguments(index_or_candidate=index_or_candidate, flat=flat)
+
+    def check_discrepancies(
+        self,
+        onnx_model: str | onnx.ModelProto,
+        atol: float = 1e-4,
+        rtol: float = 0.1,
+        hist=(0.1, 0.01),
+        progress_bar: bool = False,
+        include_io: bool = True,
+    ) -> list[dict[str, str | int | float | bool]]:
+        """Computes the discrepancies between the saved inputs and outputs
+        and the saved onnx model.
+
+        Args:
+            onnx_model:
+                ONNX Model to verify.
+            atol:
+                Absolute tolerance, recommended values: 1e-4 for float32, 1e-2 for float16.
+            rtol:
+                Relative tolerance.
+            hist:
+                Thresholds, the function determines the number of discrepancies
+                above these thresholds.
+            progress_bar:
+                Shows a progress bar (requires :epkg:`tqdm`).
+            include_io:
+                Shows inputs/outputs shapes in the summary
+                returned by this function.
+
+        Returns:
+            A list of dictionaries, ready to be consumed by a dataframe.
+
+        The function catches exceptions, it shows the error in the returned
+        summary.
+        """
+        sess = OnnxruntimeEvaluator(onnx_model, whole=True)
+        input_names = sess.input_names
+        self._check_captured()
+        # type checking
+        assert self.info is not None
+        assert self.info.inputs is not None
+        assert self.info.flat_outputs is not None
+        assert self.info.latencies is not None
+        io_sets = list(zip(self.info.inputs, self.info.flat_outputs, self.info.latencies))
+        if progress_bar:
+            from tqdm import tqdm
+
+            loop = tqdm(io_sets)
+        else:
+            loop = io_sets
+        lhist = list(hist)
+        data: list[dict[str, Any]] = []
+        for inputs, outputs, latency in loop:
+            # type checking
+            assert inputs.aligned_flat_list is not None
+            if len(input_names) != len(inputs.aligned_flat_list):
+                raise RuntimeError(
+                    f"There are ({len(inputs.aligned_flat_list)}) "
+                    f"tensors but the model expects {len(input_names)}."
+                )
+            n_none = sum([t is None for t in inputs.aligned_flat_list])
+            n_empty = sum([t is None or t.numel() == 0 for t in inputs.aligned_flat_list])
+
+            feeds = dict(zip(input_names, self.info.infer_arguments(inputs, flat=True)))
+
+            begin = time.perf_counter()
+            try:
+                ort_outputs = sess.run(None, feeds)
+                error = None
+            except Exception as e:
+                error = str(e)
+                ort_outputs = None
+
+            duration = time.perf_counter() - begin
+            if error:
+                diff: dict[str, str | int | float | bool] = dict(error=error, SUCCESS=False)
+            else:
+                # The last output may be empty and torch could skip it.
+                if isinstance(outputs, list) and isinstance(ort_outputs, list):
+                    while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0:
+                        ort_outputs.pop()
+                diff = max_diff(outputs, ort_outputs, hist=lhist)  # type: ignore[assignment]
+                if "rep" in diff and isinstance(diff["rep"], dict):
+                    diff.update(diff["rep"])
+                    del diff["rep"]
+                diff["SUCCESS"] = (
+                    isinstance(diff["abs"], float)
+                    and isinstance(diff["rel"], float)
+                    and diff["abs"] < atol
+                    and diff["rel"] < rtol
+                )
+            diff.update(
+                dict(
+                    index=len(diff),
+                    duration_torch=latency,
+                    ort_duration=duration,
+                    n_inputs=len(input_names),
+                    n_none=n_none,
+                    n_empty=n_empty,
+                )
+            )
+            if include_io:
+                diff["inputs"] = string_type(feeds, with_shape=True)
+                diff["outputs_torch"] = string_type(outputs, with_shape=True)
+                diff["outputs_ort"] = string_type(ort_outputs, with_shape=True)
+            data.append(diff)
+        return data
```
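
The docstring examples above sketch the intended workflow at a high level. The following is a minimal end-to-end sketch of the new `InputObserver` API; the toy `Model`, the tensor shapes, and the argument names are illustrative assumptions, only the `InputObserver` calls come from the API added in this release.

```python
import torch
from onnx_diagnostic.investigate.input_observer import InputObserver


class Model(torch.nn.Module):
    # Toy module standing in for a real model (an assumption of this sketch).
    def forward(self, x, y):
        return x + y


model = Model()
observer = InputObserver()

# The context manager swaps ``forward`` for a recording wrapper and
# restores the original method on exit.
with observer(model, store_n_calls=3):
    model(torch.rand(2, 4), torch.rand(2, 4))
    model(torch.rand(2, 6), torch.rand(2, 6))  # dimension 1 varies across calls

# Dimension 1 is inferred dynamic because it changed between the calls.
# ``set_batch_dimension_for`` additionally forces dimension 0 of ``x`` and
# ``y`` to be dynamic even though both calls used the same batch size.
dynamic_shapes = observer.infer_dynamic_shapes(set_batch_dimension_for={"x", "y"})
args = observer.infer_arguments()

ep = torch.export.export(model, tuple(args), dynamic_shapes=dynamic_shapes)
```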
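
The dimension-inference rule added in `_infer_dynamic_dimensions` is compact enough to restate in isolation: a dimension is marked dynamic when its value differs across the captured shapes, and `set_batch_dimension` forces dimension 0 on top of that. A standalone restatement of that rule (the package version routes the rank check through `torch._check`):

```python
from typing import Sequence


def infer_dynamic_dimensions(
    shape_list: Sequence[tuple[int, ...]], set_batch_dimension: bool = False
) -> list[int]:
    # All shapes describe the same tensor, so they must share one rank.
    ranks = {len(shape) for shape in shape_list}
    assert len(ranks) == 1, "all shapes must have the same rank"
    dynamic = []
    for i in range(ranks.pop()):
        dims = {shape[i] for shape in shape_list}
        # Dynamic if the value varies, or if the batch dimension is forced.
        if len(dims) > 1 or (i == 0 and set_batch_dimension):
            dynamic.append(i)
    return dynamic


print(infer_dynamic_dimensions([(2, 8), (2, 12)]))                            # [1]
print(infer_dynamic_dimensions([(2, 8), (2, 12)], set_batch_dimension=True))  # [0, 1]
```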
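
The new `check_discrepancies` method closes the loop: it replays every captured set of inputs through onnxruntime and compares the results with the captured torch outputs. A hypothetical follow-up to the sketch above, assuming the model was also saved to a file named `model.onnx` (for instance with `torch.onnx.export`); per the docstring, the returned list of dictionaries is meant to be loaded into a dataframe:

```python
import pandas

# Replay the captured inputs through onnxruntime and compare with the
# captured torch outputs; runtime errors are caught and reported per row.
summary = observer.check_discrepancies(
    "model.onnx",  # hypothetical path to the exported model
    atol=1e-4,     # recommended for float32; use 1e-2 for float16
    rtol=0.1,
)
df = pandas.DataFrame(summary)
# Rows carry SUCCESS, abs/rel differences, duration_torch, ort_duration,
# n_inputs, n_none and n_empty, plus input/output shapes when include_io=True.
print(df)
```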