onnx_diagnostic-0.8.9-py3-none-any.whl → onnx_diagnostic-0.8.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +136 -140
- onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
- onnx_diagnostic/export/api.py +24 -12
- onnx_diagnostic/export/validate.py +2 -0
- onnx_diagnostic/ext_test_case.py +32 -15
- onnx_diagnostic/helpers/args_helper.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +0 -1
- onnx_diagnostic/helpers/cache_helper.py +6 -6
- onnx_diagnostic/helpers/doc_helper.py +7 -4
- onnx_diagnostic/helpers/graph_helper.py +6 -6
- onnx_diagnostic/helpers/log_helper.py +37 -14
- onnx_diagnostic/helpers/memory_peak.py +5 -1
- onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
- onnx_diagnostic/helpers/model_builder_helper.py +1 -1
- onnx_diagnostic/helpers/onnx_helper.py +283 -110
- onnx_diagnostic/helpers/ort_session.py +0 -1
- onnx_diagnostic/helpers/torch_helper.py +8 -9
- onnx_diagnostic/investigate/__init__.py +0 -0
- onnx_diagnostic/investigate/input_observer.py +329 -0
- onnx_diagnostic/reference/evaluator.py +0 -1
- onnx_diagnostic/reference/ort_evaluator.py +0 -1
- onnx_diagnostic/reference/report_results_comparison.py +9 -3
- onnx_diagnostic/reference/torch_evaluator.py +5 -1
- onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
- onnx_diagnostic/tasks/feature_extraction.py +0 -1
- onnx_diagnostic/torch_export_patches/__init__.py +0 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +5 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +14 -13
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
- onnx_diagnostic/torch_models/code_sample.py +5 -10
- onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +7 -12
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -0
- onnx_diagnostic/torch_models/validate.py +1 -1
- onnx_diagnostic/torch_onnx/compare.py +0 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
- onnx_diagnostic/typing.py +15 -0
- {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +48 -46
- {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
- onnx_diagnostic/api.py +0 -15
- {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.9.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
```diff
@@ -0,0 +1,329 @@
+import contextlib
+import inspect
+from typing import Any, Callable, Sequence
+import torch
+
+
+def flatten_unflatten_for_dynamic_shapes(
+    obj: Any,
+    use_dict: bool = True,
+    change_function: Callable[[torch.Tensor], Any] | None = None,
+) -> Any:
+    """
+    Returns the object in a different structure similar to what
+    the definition of the dynamic shapes should use.
+
+    Args:
+        obj:
+            object from a custom class
+        use_dict:
+            closer to the original result but
+            :func:`torch.export.export` only considers the values,
+            the context gives the dictionary keys but it is not expressed
+            in the dynamic shapes, these specifications seems to be different
+            for the strict and non strict mode. It also preserves tuple.
+        change_function:
+            to modify the tensor in the structure itself,
+            like replace them by a shape
+
+    Returns:
+        the serialized object
+    """
+    if isinstance(obj, torch.Tensor):
+        return change_function(obj) if change_function else obj
+    flat, spec = torch.utils._pytree.tree_flatten(obj)
+    start = 0
+    end = 0
+    subtrees = []
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
+        end += subspec.num_leaves
+        value = subspec.unflatten(flat[start:end])
+        value = flatten_unflatten_for_dynamic_shapes(
+            value, use_dict=use_dict, change_function=change_function
+        )
+        subtrees.append(value)
+        start = end
+    if use_dict:
+        if spec.type is dict:
+            # This is a dictionary.
+            return dict(zip(spec.context, subtrees))
+        if spec.type is tuple:
+            return tuple(subtrees)
+        if spec.type is list:
+            return list(subtrees)
+        if spec.type is None and not subtrees:
+            return None
+        if spec.context:
+            # This is a custom class with attributes.
+            # It is returned as a list.
+            return list(subtrees)
+        raise ValueError(
+            f"Unable to interpret spec type {spec.type} "
+            f"(type is {type(spec.type)}, context is {spec.context}), "
+            f"spec={spec}, subtrees={subtrees}"
+        )
+    # This is a list.
+    return subtrees
+
+
+def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int]:
+    """
+    Returns the list of dynamic dimensions given a list of shapes
+    corresponding to the same tensor.
+
+    Args:
+        shape_list:
+            list of shapes, they must all have the same length
+
+    Returns:
+        list of dynamic dimensions
+    """
+    unique_ranks = {len(shape) for shape in shape_list}
+    torch._check(
+        len(unique_ranks) == 1, lambda: "all shapes in shape_list must have the same rank"
+    )
+    rank = unique_ranks.pop()
+    dynamic = []
+    for i in range(rank):
+        dims = [shape[i] for shape in shape_list]
+        if len(set(dims)) > 1:
+            dynamic.append(i)
+    return dynamic
+
+
+class InputObserverInfo:
+    def __init__(self, signature: inspect.Signature):
+        # pyrefly: ignore
+        self.inputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
+        self.flat_inputs: list[list[torch.Tensor | None]] = []
+        # pyrefly: ignore
+        self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
+        self.flat_outputs: list[torch.Tensor | list[torch.Tensor]] = []
+        self.signature = signature
+
+        self._max_args: tuple[Any, torch.Tensor] | None = None
+        self._max_kwargs: dict[str, torch.Tensor] | None = None
+
+    def __len__(self) -> int:
+        return len(self.flat_inputs)
+
+    def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
+        kwargs = {
+            k: v
+            for k, v in kwargs.items()
+            if v is not None and not isinstance(v, (int, float, bool))
+        }
+        flat_args, spec = torch.utils._pytree.tree_flatten((args, kwargs))
+        self.inputs_specs.append(spec)
+        cloned = [
+            (None if not isinstance(t, torch.Tensor) else t.clone().detach())
+            for t in flat_args
+        ]
+        self.flat_inputs.append(cloned)
+
+        cloned_args, cloned_kwargs = torch.utils._pytree.tree_unflatten(cloned, spec)
+        if self._max_args is None or len(cloned_args) > len(self._max_args):
+            self._max_args = cloned_args
+        if self._max_kwargs is None or len(cloned_kwargs) > len(self._max_kwargs):
+            self._max_kwargs = cloned_kwargs
+
+    def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
+        flat_res, spec = torch.utils._pytree.tree_flatten(res)
+        self.outputs_specs.append(spec)
+        self.flat_outputs.append([t.clone().detach() for t in flat_res])
+
+    def build_inputs_completed_with_none_values(self) -> list[list[torch.Tensor]]:
+        # Let's compute the sizes of each independently.
+        if not self.flat_inputs or self._max_args is None or self._max_kwargs is None:
+            raise RuntimeError("No inputs were captured.")
+        arg_sizes = [len(torch.utils._pytree.tree_flatten(a)[0]) for a in self._max_args]
+        kwarg_sizes = {
+            k: len(torch.utils._pytree.tree_flatten(v)[0]) for k, v in self._max_kwargs.items()
+        }
+
+        # Let's reprocess everything.
+        captured_inputs: dict[int | str, int] = {}
+        new_flat_inputs = []
+        for args_kwargs, spec in zip(self.flat_inputs, self.inputs_specs):
+            args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
+            if len(set(kwargs) | set(self._max_kwargs)) > len(self._max_kwargs):
+                raise RuntimeError(
+                    "At least one call to the observed model "
+                    "must contain all the named arguments."
+                )
+            flat = []
+            for i in range(len(self._max_args)):
+                if i < len(args):
+                    ts = torch.utils._pytree.tree_flatten(args[i])[0]
+                    if i in captured_inputs and captured_inputs[i] != len(ts):
+                        raise RuntimeError(
+                            f"Positional argument {i} has {len(ts)} tensors "
+                            f"but previously got {captured_inputs[i]} tensors. "
+                            f"Inference is impossible in that case."
+                        )
+                    captured_inputs[i] = len(ts)
+                    flat.extend(ts)
+                else:
+                    flat.extend([None for _ in range(arg_sizes[i])])
+            for k in self._max_kwargs:
+                if k in kwargs:
+                    ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
+                    if k in captured_inputs and captured_inputs[k] != len(ts):
+                        raise RuntimeError(
+                            f"Named argument {k!r} has {len(ts)} tensors "
+                            f"but previously got {captured_inputs[k]} tensors. "
+                            f"Inference is impossible in that case."
+                        )
+                    captured_inputs[k] = len(ts)
+                    flat.extend(ts)
+                else:
+                    flat.extend([None for _ in range(kwarg_sizes[k])])
+            new_flat_inputs.append(flat)
+        return new_flat_inputs
+
+    def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+        flat_inputs = self.build_inputs_completed_with_none_values()
+        # This is already checked by build_inputs_completed_with_none_values
+        # but this is not always well captured by tools checking types.
+        assert self._max_args is not None and self._max_kwargs is not None
+        if len({len(flat) for flat in flat_inputs}) != 1:
+            raise NotImplementedError(
+                "infer_dynamic_shapes is not implemented "
+                "when the number of input tensors are not the same."
+            )
+        shape_lists = [
+            [(None if t is None else t.shape) for t in tensors] for tensors in flat_inputs
+        ]
+        n_tensors = len(shape_lists[0])
+        dynamic_shapes = [
+            infer_dynamic_dimensions(
+                [s for s in [shapes[index] for shapes in shape_lists] if s is not None]
+            )
+            for index in range(n_tensors)
+        ]
+        cst = torch.export.Dim.DYNAMIC
+        flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
+        if len(flat_dynamic_shapes) == len(self._max_args) + len(self._max_kwargs):
+            # It means forward method is called with tensors only.
+            if not self._max_kwargs:
+                # only positional arguments
+                return tuple(flat_dynamic_shapes)
+            if not self._max_args:
+                # only named arguments
+                return dict(zip(list(self._max_kwargs), flat_dynamic_shapes))
+            # positional arguments needs to be moved to the named arguments
+            n_args = len(self._max_args)
+            pos_names = list(self.signature.parameters)[:n_args]
+            return {
+                **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
+                **dict(zip(list(self._max_kwargs), flat_dynamic_shapes[n_args:])),
+            }
+
+        # nested types, here comes the fun part because the shapes cannot be unflattened,
+        # custom classes must appear in their flattened shape.
+        # This does not work in all cases but every time every available argument is flattened
+        # with the same number of tensors. The function does not check
+        # if that assumption is true.
+        flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
+            (self._max_args, self._max_kwargs)
+        )
+        torch._check(
+            len(flat_inputs) == len(flat_dynamic_shapes),
+            (
+                f"Length mismatch len(flat_inputs)={len(flat_inputs)}, "
+                f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
+            ),
+        )
+        mapping = {id(t): shape for t, shape in zip(flat_inputs, flat_dynamic_shapes)}
+        ds_args, ds_kwargs = flatten_unflatten_for_dynamic_shapes(
+            (self._max_args, self._max_kwargs), change_function=lambda t: mapping[id(t)]
+        )
+        if not ds_kwargs:
+            return tuple(ds_args)
+        if not ds_args:
+            return tuple(ds_kwargs)
+        pos_names = list(self.signature.parameters)[: len(ds_args)]
+        return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
+
+    def infer_arguments(
+        self, index: int | None = None
+    ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        # This is already checked by build_inputs_completed_with_none_values
+        # but this is not always well captured by tools checking types.
+        assert self._max_args is not None and self._max_kwargs is not None
+        candidate = None
+        if index is None:
+            for i, (args_kwargs, spec) in enumerate(zip(self.flat_inputs, self.inputs_specs)):
+                args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
+                if len(args) == len(self._max_args) and len(kwargs) == len(self._max_kwargs):
+                    index = i
+                    candidate = args, kwargs
+                    break
+        if index is not None:
+            # found one available set.
+            args, kwargs = candidate or torch.utils._pytree.tree_unflatten(
+                self.flat_inputs[index], self.inputs_specs[index]
+            )
+            if not kwargs:
+                return args
+            if not args:
+                return kwargs
+            # We need to move args to kwargs
+            pos_names = list(self.signature.parameters)[: len(args)]
+            return {**dict(zip(pos_names, args)), **kwargs}
+
+        raise NotImplementedError(
+            "We could not find a good set of inputs/outputs. "
+            "We need to replace none by empty tensors."
+        )
+
+
+class InputObserver:
+    def __init__(self, store_n_calls: int = 3):
+        self.store_n_calls = store_n_calls
+        self.info: InputObserverInfo | None = None
+
+    def _forward_captured(self, *args, _captured_forward=None, **kwargs):
+        assert _captured_forward is not None, "_captured_forward cannot be None"
+        assert self.info is not None, "info cannot be None"
+        n_stored = len(self.info)
+        if n_stored < self.store_n_calls:
+            self.info.add_inputs(args, kwargs)
+        res = _captured_forward(*args, **kwargs)
+        if n_stored < self.store_n_calls:
+            self.info.add_outputs(res)
+        return res
+
+    @contextlib.contextmanager
+    def __call__(self, model: torch.nn.Module):
+        if self.info is not None:
+            raise RuntimeError(
+                "This class was already used to capture a model. Please create a new one."
+            )
+        self.info = InputObserverInfo(signature=inspect.signature(model.forward))
+        forward_method = model.forward
+        model.forward = (
+            lambda *args, _captured_forward=forward_method, **kwargs: self._forward_captured(
+                *args, _captured_forward=_captured_forward, **kwargs
+            )
+        )
+        try:
+            yield self
+        finally:
+            model.forward = forward_method
+
+    def _check_captured(self):
+        if self.info is None:
+            raise RuntimeError("No inputs were captured.")
+
+    def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+        self._check_captured()
+        assert self.info is not None  # missed by type checking
+        return self.info.infer_dynamic_shapes()
+
+    def infer_arguments(
+        self, index: int | None = None
+    ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        self._check_captured()
+        assert self.info is not None  # missed by type checking
+        return self.info.infer_arguments(index=index)
```
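The bulk of this release is the new `onnx_diagnostic/investigate/input_observer.py` module: `InputObserver` temporarily wraps a model's `forward`, records the first `store_n_calls` sets of inputs and outputs, and can then propose `dynamic_shapes` and arguments for `torch.export.export`. A minimal usage sketch based only on the code above (the model here is hypothetical, for illustration):

```python
import torch
from onnx_diagnostic.investigate.input_observer import InputObserver

class TinyModel(torch.nn.Module):  # hypothetical model, not part of the package
    def forward(self, x):
        return x * 2

model = TinyModel()
observer = InputObserver(store_n_calls=3)
with observer(model):  # wraps model.forward, restores it on exit
    model(torch.randn(2, 8))
    model(torch.randn(5, 8))  # dim 0 varies across calls -> marked dynamic

ds = observer.infer_dynamic_shapes()  # ({0: torch.export.Dim.DYNAMIC},)
args = observer.infer_arguments()     # one recorded set of positional inputs
ep = torch.export.export(model, args, dynamic_shapes=ds)
```

`infer_dynamic_shapes` compares the recorded shapes per flattened tensor and marks a dimension dynamic as soon as two calls disagree on it.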
```diff
@@ -34,7 +34,6 @@ from ..helpers.torch_helper import to_tensor
 from .report_results_comparison import ReportResultComparison
 from .evaluator import ExtendedReferenceEvaluator
 
-
 PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
 Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
 
```
```diff
@@ -1,5 +1,4 @@
-from typing import Any, Dict, List, Tuple, Union
-
+from typing import Any, Dict, List, Set, Tuple, Union
 
 ReportKeyNameType = Union[str, Tuple[str, int, str]]
 ReportKeyValueType = Tuple[int, Tuple[int, ...]]
```
```diff
@@ -14,6 +13,7 @@ class ReportResultComparison:
     :param tensors: tensor
     """
 
+    # pyrefly: ignore[unknown-name]
     def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]):  # noqa: F821
         from ..helpers.onnx_helper import dtype_to_tensor_dtype
         from ..helpers import max_diff, string_type
```
```diff
@@ -25,7 +25,9 @@ class ReportResultComparison:
         self.max_diff = max_diff
         self.tensors = tensors
         self._build_mapping()
+        self.unique_run_names: Set[str] = set()
 
+    # pyrefly: ignore[unknown-name]
     def key(self, tensor: "torch.Tensor") -> ReportKeyValueType:  # noqa: F821
         "Returns a key for a tensor, (onnx dtype, shape)."
         return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
```
```diff
@@ -59,12 +61,15 @@ class ReportResultComparison:
         for k, v in self.value.items():
             (i_run, run_name), ref_name = k
             d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
+            # pyrefly: ignore[no-matching-overload]
             d.update(v)
             rows.append(d)
         return rows
 
     def report(
-        self,
+        self,
+        # pyrefly: ignore[unknown-name]
+        outputs: Dict[str, "torch.Tensor"],  # noqa: F821
     ) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
         """
         For every tensor in outputs, compares it to every tensor held by
```
```diff
@@ -79,6 +84,7 @@ class ReportResultComparison:
             key = self.key(tensor)
             if key not in self.mapping:
                 continue
+            # pyrefly: ignore[unknown-name]
             cache: Dict["torch.device", "torch.Tensor"] = {}  # noqa: F821, UP037
             for held_key in self.mapping[key]:
                 t2 = self.tensors[held_key]
```
```diff
@@ -63,7 +63,7 @@ class TorchOnnxEvaluator:
     * `functions`: local functions
 
     The class is not multithreaded. `runtime_info` gets updated
-    by the
+    by the class. The list of available kernels is returned by function
     :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
     Example:
 
```
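The completed docstring sentence points to `get_kernels`; assuming it can be called without arguments and returns an iterable of kernel implementations (the diff only names the function), inspecting what the evaluator supports might look like:

```python
from onnx_diagnostic.reference.torch_evaluator import get_kernels

# Assumption: get_kernels() takes no required arguments and returns an
# iterable of the kernel implementations TorchOnnxEvaluator can execute.
kernels = list(get_kernels())
print(f"{len(kernels)} kernels available")
```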
```diff
@@ -494,8 +494,10 @@ class TorchOnnxEvaluator:
             r = self.runtime_info[k]
             r.set_value(
                 torch_ops.OpRunTensor(
+                    # pyrefly: ignore[missing-attribute]
                     v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                     is_constant=False,
+                    # pyrefly: ignore[missing-attribute]
                     may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                 )
             )
```
```diff
@@ -524,6 +526,7 @@ class TorchOnnxEvaluator:
                 f"for kernel {type(kernel)}."
             )
             for name, t in zip(kernel.output, res):
+                # pyrefly: ignore[bad-argument-type]
                 self.runtime_info[name].set_value(t)
             if self.verbose:
                 for name in kernel.output:
```
```diff
@@ -644,6 +647,7 @@ class TorchOnnxEvaluator:
                 f"for kernel {type(kernel)}."
             )
             for name, t in zip(kernel.output, res):
+                # pyrefly: ignore[bad-argument-type]
                 self.runtime_info[name].set_value(t)
         else:
             assert isinstance(
```
```diff
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional, Union, Tuple
 import onnx
 import torch
-from ...
+from ...typing import TensorLike
 from ...helpers import string_type
 from ...helpers.torch_helper import to_tensor
 
```
```diff
@@ -149,7 +149,7 @@ class OpRunSequence(OpRunValue):
     ) -> "OpRunSequence":
         "Inserts a value at a given position."
         assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
-        new_seq = OpRunSequence()
+        new_seq = OpRunSequence()  # type: ignore[abstract]
         seq = self.sequence.copy()
         new_seq.sequence = seq
         if position is None:
```
```diff
@@ -314,9 +314,7 @@ class OpRunKernel:
 
 
 class OpRunFunction(OpRunKernel):
-    """
-    Defines a kernel based on a local functions.
-    """
+    """Defines a kernel based on a local functions."""
 
     def __init__(
         self,
```
```diff
@@ -305,7 +305,7 @@ def serialization_functions(
 
 
 def unregister_class_serialization(cls: type, verbose: int = 0):
-    """Undo the registration."""
+    """Undo the registration for a class."""
     # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
     if cls in torch.fx._pytree.SUPPORTED_NODES:
         del torch.fx._pytree.SUPPORTED_NODES[cls]
```
```diff
@@ -333,6 +333,10 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
 
 
 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
+    """
+    Undo the registration made by
+    :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`.
+    """
     cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
     for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
```
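The new docstring pairs `unregister_cache_serialization` with `register_cache_serialization`. A hedged sketch of the intended round trip, assuming `register_cache_serialization` returns the `Dict[str, bool]` that the unregister function expects (only the unregister signature is visible in this diff):

```python
from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    register_cache_serialization,
    unregister_cache_serialization,
)

# Assumption: the return value is the `undo` mapping consumed below.
undo = register_cache_serialization()
try:
    pass  # export code relying on DynamicCache/EncoderDecoderCache serialization
finally:
    unregister_cache_serialization(undo)
```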
```diff
@@ -7,10 +7,10 @@ import transformers
 
 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
-    device: Optional[
+    device: Optional[torch.device] = None,
     seq_len: Optional[int] = None,
     **rope_kwargs,
-) -> Tuple[
+) -> Tuple[torch.Tensor, float]:
     """
     manual patch:
     ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
```
```diff
@@ -524,13 +524,16 @@ class patched_ShapeEnv:
 
         transmute_into_runtime_assert = False
 
+        backed_var_to_val = getattr(
+            self, "backed_var_to_val", getattr(self, "var_to_val", {})
+        )
         concrete_val = None
-        if not (expr.free_symbols <=
+        if not (expr.free_symbols <= backed_var_to_val.keys()):
            # TODO: dedupe this with _maybe_evaluate_static
            # Attempt to eliminate the unbacked SymInt
            new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
            assert new_expr is not None
-            if not (new_expr.free_symbols <=
+            if not (new_expr.free_symbols <= backed_var_to_val.keys()):
                ok = False
 
            # fallback_value is set when guard_or_true or guard_or_false are used.
```
```diff
@@ -541,17 +544,15 @@
             # oblivious_var_to_val will be defined iff we have sizes
             # with DimDynamic.OBLIVIOUS_SIZE type.
             # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
-            var_to_val = getattr(
-                self,
-                "unbacked_var_to_val",
-                getattr(self, "oblivious_var_to_val", False),
-            )
             if (
-
-                and
+                backed_var_to_val
+                and getattr(self, "real_tensor_prop_unbacked_vals", True)
+                and not (
+                    correct_hint := orig_expr.xreplace(backed_var_to_val)
+                ).free_symbols
                 and not (
                     counterfactual_hint := orig_expr.xreplace(
-                        {k: max(2, v) for k, v in
+                        {k: max(2, v) for k, v in backed_var_to_val.items()}
                     )
                 ).free_symbols
                 and correct_hint == counterfactual_hint
```
```diff
@@ -574,10 +575,10 @@
             # and if they pass we add a runtime assertions and continue.
             if (
                 not ok
-                and
+                and backed_var_to_val
                 and not (
-                    unsound_result := orig_expr.xreplace(
-
+                    unsound_result := orig_expr.xreplace(backed_var_to_val).xreplace(
+                        backed_var_to_val
                     )
                 ).free_symbols
             ):
```
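Both `patched_ShapeEnv` hunks converge on the same idea: read `backed_var_to_val` through a `getattr` chain instead of a direct attribute access, presumably so the patch keeps working across torch versions that expose either `backed_var_to_val` or the older `var_to_val`. The pattern in isolation (the stand-in class is illustrative only):

```python
class _OldShapeEnv:  # stand-in for a torch version that only has var_to_val
    var_to_val = {"s0": 4}

env = _OldShapeEnv()
# Fall back through attribute names, ending at an empty mapping so that
# tests like `expr.free_symbols <= backed_var_to_val.keys()` stay well defined.
backed_var_to_val = getattr(env, "backed_var_to_val", getattr(env, "var_to_val", {}))
print(backed_var_to_val)  # {'s0': 4}
```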
|