onnx-diagnostic 0.8.10__py3-none-any.whl → 0.8.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +136 -140
- onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
- onnx_diagnostic/export/api.py +2 -4
- onnx_diagnostic/export/validate.py +2 -0
- onnx_diagnostic/ext_test_case.py +32 -15
- onnx_diagnostic/helpers/args_helper.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +0 -1
- onnx_diagnostic/helpers/cache_helper.py +6 -6
- onnx_diagnostic/helpers/doc_helper.py +7 -4
- onnx_diagnostic/helpers/graph_helper.py +6 -6
- onnx_diagnostic/helpers/log_helper.py +37 -14
- onnx_diagnostic/helpers/memory_peak.py +5 -1
- onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
- onnx_diagnostic/helpers/model_builder_helper.py +1 -1
- onnx_diagnostic/helpers/onnx_helper.py +283 -110
- onnx_diagnostic/helpers/ort_session.py +0 -1
- onnx_diagnostic/helpers/torch_helper.py +8 -9
- onnx_diagnostic/investigate/__init__.py +0 -0
- onnx_diagnostic/investigate/input_observer.py +329 -0
- onnx_diagnostic/reference/evaluator.py +0 -1
- onnx_diagnostic/reference/ort_evaluator.py +0 -1
- onnx_diagnostic/reference/report_results_comparison.py +9 -3
- onnx_diagnostic/reference/torch_evaluator.py +5 -1
- onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
- onnx_diagnostic/tasks/feature_extraction.py +0 -1
- onnx_diagnostic/torch_export_patches/__init__.py +0 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
- onnx_diagnostic/torch_models/code_sample.py +5 -10
- onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
- onnx_diagnostic/torch_models/validate.py +1 -1
- onnx_diagnostic/torch_onnx/compare.py +0 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
- onnx_diagnostic/typing.py +15 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +45 -43
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
- onnx_diagnostic/api.py +0 -15
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,329 @@
+import contextlib
+import inspect
+from typing import Any, Callable, Sequence
+import torch
+
+
+def flatten_unflatten_for_dynamic_shapes(
+    obj: Any,
+    use_dict: bool = True,
+    change_function: Callable[[torch.Tensor], Any] | None = None,
+) -> Any:
+    """
+    Returns the object in a different structure similar to what
+    the definition of the dynamic shapes should use.
+
+    Args:
+        obj:
+            object from a custom class
+        use_dict:
+            closer to the original result but
+            :func:`torch.export.export` only considers the values,
+            the context gives the dictionary keys but it is not expressed
+            in the dynamic shapes, these specifications seems to be different
+            for the strict and non strict mode. It also preserves tuple.
+        change_function:
+            to modify the tensor in the structure itself,
+            like replace them by a shape
+
+    Returns:
+        the serialized object
+    """
+    if isinstance(obj, torch.Tensor):
+        return change_function(obj) if change_function else obj
+    flat, spec = torch.utils._pytree.tree_flatten(obj)
+    start = 0
+    end = 0
+    subtrees = []
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
+        end += subspec.num_leaves
+        value = subspec.unflatten(flat[start:end])
+        value = flatten_unflatten_for_dynamic_shapes(
+            value, use_dict=use_dict, change_function=change_function
+        )
+        subtrees.append(value)
+        start = end
+    if use_dict:
+        if spec.type is dict:
+            # This is a dictionary.
+            return dict(zip(spec.context, subtrees))
+        if spec.type is tuple:
+            return tuple(subtrees)
+        if spec.type is list:
+            return list(subtrees)
+        if spec.type is None and not subtrees:
+            return None
+        if spec.context:
+            # This is a custom class with attributes.
+            # It is returned as a list.
+            return list(subtrees)
+        raise ValueError(
+            f"Unable to interpret spec type {spec.type} "
+            f"(type is {type(spec.type)}, context is {spec.context}), "
+            f"spec={spec}, subtrees={subtrees}"
+        )
+    # This is a list.
+    return subtrees
+
+
+def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int]:
+    """
+    Returns the list of dynamic dimensions given a list of shapes
+    corresponding to the same tensor.
+
+    Args:
+        shape_list:
+            list of shapes, they must all have the same length
+
+    Returns:
+        list of dynamic dimensions
+    """
+    unique_ranks = {len(shape) for shape in shape_list}
+    torch._check(
+        len(unique_ranks) == 1, lambda: "all shapes in shape_list must have the same rank"
+    )
+    rank = unique_ranks.pop()
+    dynamic = []
+    for i in range(rank):
+        dims = [shape[i] for shape in shape_list]
+        if len(set(dims)) > 1:
+            dynamic.append(i)
+    return dynamic
+
+
+class InputObserverInfo:
+    def __init__(self, signature: inspect.Signature):
+        # pyrefly: ignore
+        self.inputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
+        self.flat_inputs: list[list[torch.Tensor | None]] = []
+        # pyrefly: ignore
+        self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
+        self.flat_outputs: list[torch.Tensor | list[torch.Tensor]] = []
+        self.signature = signature
+
+        self._max_args: tuple[Any, torch.Tensor] | None = None
+        self._max_kwargs: dict[str, torch.Tensor] | None = None
+
+    def __len__(self) -> int:
+        return len(self.flat_inputs)
+
+    def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
+        kwargs = {
+            k: v
+            for k, v in kwargs.items()
+            if v is not None and not isinstance(v, (int, float, bool))
+        }
+        flat_args, spec = torch.utils._pytree.tree_flatten((args, kwargs))
+        self.inputs_specs.append(spec)
+        cloned = [
+            (None if not isinstance(t, torch.Tensor) else t.clone().detach())
+            for t in flat_args
+        ]
+        self.flat_inputs.append(cloned)
+
+        cloned_args, cloned_kwargs = torch.utils._pytree.tree_unflatten(cloned, spec)
+        if self._max_args is None or len(cloned_args) > len(self._max_args):
+            self._max_args = cloned_args
+        if self._max_kwargs is None or len(cloned_kwargs) > len(self._max_kwargs):
+            self._max_kwargs = cloned_kwargs
+
+    def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]):
+        flat_res, spec = torch.utils._pytree.tree_flatten(res)
+        self.outputs_specs.append(spec)
+        self.flat_outputs.append([t.clone().detach() for t in flat_res])
+
+    def build_inputs_completed_with_none_values(self) -> list[list[torch.Tensor]]:
+        # Let's compute the sizes of each independently.
+        if not self.flat_inputs or self._max_args is None or self._max_kwargs is None:
+            raise RuntimeError("No inputs were captured.")
+        arg_sizes = [len(torch.utils._pytree.tree_flatten(a)[0]) for a in self._max_args]
+        kwarg_sizes = {
+            k: len(torch.utils._pytree.tree_flatten(v)[0]) for k, v in self._max_kwargs.items()
+        }
+
+        # Let's reprocess everything.
+        captured_inputs: dict[int | str, int] = {}
+        new_flat_inputs = []
+        for args_kwargs, spec in zip(self.flat_inputs, self.inputs_specs):
+            args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
+            if len(set(kwargs) | set(self._max_kwargs)) > len(self._max_kwargs):
+                raise RuntimeError(
+                    "At least one call to the observed model "
+                    "must contain all the named arguments."
+                )
+            flat = []
+            for i in range(len(self._max_args)):
+                if i < len(args):
+                    ts = torch.utils._pytree.tree_flatten(args[i])[0]
+                    if i in captured_inputs and captured_inputs[i] != len(ts):
+                        raise RuntimeError(
+                            f"Positional argument {i} has {len(ts)} tensors "
+                            f"but previously got {captured_inputs[i]} tensors. "
+                            f"Inference is impossible in that case."
+                        )
+                    captured_inputs[i] = len(ts)
+                    flat.extend(ts)
+                else:
+                    flat.extend([None for _ in range(arg_sizes[i])])
+            for k in self._max_kwargs:
+                if k in kwargs:
+                    ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
+                    if k in captured_inputs and captured_inputs[k] != len(ts):
+                        raise RuntimeError(
+                            f"Named argument {k!r} has {len(ts)} tensors "
+                            f"but previously got {captured_inputs[k]} tensors. "
+                            f"Inference is impossible in that case."
+                        )
+                    captured_inputs[k] = len(ts)
+                    flat.extend(ts)
+                else:
+                    flat.extend([None for _ in range(kwarg_sizes[k])])
+            new_flat_inputs.append(flat)
+        return new_flat_inputs
+
+    def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+        flat_inputs = self.build_inputs_completed_with_none_values()
+        # This is already checked by build_inputs_completed_with_none_values
+        # but this is not always well captured by tools checking types.
+        assert self._max_args is not None and self._max_kwargs is not None
+        if len({len(flat) for flat in flat_inputs}) != 1:
+            raise NotImplementedError(
+                "infer_dynamic_shapes is not implemented "
+                "when the number of input tensors are not the same."
+            )
+        shape_lists = [
+            [(None if t is None else t.shape) for t in tensors] for tensors in flat_inputs
+        ]
+        n_tensors = len(shape_lists[0])
+        dynamic_shapes = [
+            infer_dynamic_dimensions(
+                [s for s in [shapes[index] for shapes in shape_lists] if s is not None]
+            )
+            for index in range(n_tensors)
+        ]
+        cst = torch.export.Dim.DYNAMIC
+        flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
+        if len(flat_dynamic_shapes) == len(self._max_args) + len(self._max_kwargs):
+            # It means forward method is called with tensors only.
+            if not self._max_kwargs:
+                # only positional arguments
+                return tuple(flat_dynamic_shapes)
+            if not self._max_args:
+                # only named arguments
+                return dict(zip(list(self._max_kwargs), flat_dynamic_shapes))
+            # positional arguments needs to be moved to the named arguments
+            n_args = len(self._max_args)
+            pos_names = list(self.signature.parameters)[:n_args]
+            return {
+                **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
+                **dict(zip(list(self._max_kwargs), flat_dynamic_shapes[n_args:])),
+            }
+
+        # nested types, here comes the fun part because the shapes cannot be unflattened,
+        # custom classes must appear in their flattened shape.
+        # This does not work in all cases but every time every available argument is flattened
+        # with the same number of tensors. The function does not check
+        # if that assumption is true.
+        flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
+            (self._max_args, self._max_kwargs)
+        )
+        torch._check(
+            len(flat_inputs) == len(flat_dynamic_shapes),
+            (
+                f"Length mismatch len(flat_inputs)={len(flat_inputs)}, "
+                f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
+            ),
+        )
+        mapping = {id(t): shape for t, shape in zip(flat_inputs, flat_dynamic_shapes)}
+        ds_args, ds_kwargs = flatten_unflatten_for_dynamic_shapes(
+            (self._max_args, self._max_kwargs), change_function=lambda t: mapping[id(t)]
+        )
+        if not ds_kwargs:
+            return tuple(ds_args)
+        if not ds_args:
+            return tuple(ds_kwargs)
+        pos_names = list(self.signature.parameters)[: len(ds_args)]
+        return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
+
+    def infer_arguments(
+        self, index: int | None = None
+    ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        # This is already checked by build_inputs_completed_with_none_values
+        # but this is not always well captured by tools checking types.
+        assert self._max_args is not None and self._max_kwargs is not None
+        candidate = None
+        if index is None:
+            for i, (args_kwargs, spec) in enumerate(zip(self.flat_inputs, self.inputs_specs)):
+                args, kwargs = torch.utils._pytree.tree_unflatten(args_kwargs, spec)
+                if len(args) == len(self._max_args) and len(kwargs) == len(self._max_kwargs):
+                    index = i
+                    candidate = args, kwargs
+                    break
+        if index is not None:
+            # found one available set.
+            args, kwargs = candidate or torch.utils._pytree.tree_unflatten(
+                self.flat_inputs[index], self.inputs_specs[index]
+            )
+            if not kwargs:
+                return args
+            if not args:
+                return kwargs
+            # We need to move args to kwargs
+            pos_names = list(self.signature.parameters)[: len(args)]
+            return {**dict(zip(pos_names, args)), **kwargs}
+
+        raise NotImplementedError(
+            "We could not find a good set of inputs/outputs. "
+            "We need to replace none by empty tensors."
+        )
+
+
+class InputObserver:
+    def __init__(self, store_n_calls: int = 3):
+        self.store_n_calls = store_n_calls
+        self.info: InputObserverInfo | None = None
+
+    def _forward_captured(self, *args, _captured_forward=None, **kwargs):
+        assert _captured_forward is not None, "_captured_forward cannot be None"
+        assert self.info is not None, "info cannot be None"
+        n_stored = len(self.info)
+        if n_stored < self.store_n_calls:
+            self.info.add_inputs(args, kwargs)
+        res = _captured_forward(*args, **kwargs)
+        if n_stored < self.store_n_calls:
+            self.info.add_outputs(res)
+        return res
+
+    @contextlib.contextmanager
+    def __call__(self, model: torch.nn.Module):
+        if self.info is not None:
+            raise RuntimeError(
+                "This class was already used to capture a model. Please create a new one."
+            )
+        self.info = InputObserverInfo(signature=inspect.signature(model.forward))
+        forward_method = model.forward
+        model.forward = (
+            lambda *args, _captured_forward=forward_method, **kwargs: self._forward_captured(
+                *args, _captured_forward=_captured_forward, **kwargs
+            )
+        )
+        try:
+            yield self
+        finally:
+            model.forward = forward_method
+
+    def _check_captured(self):
+        if self.info is None:
+            raise RuntimeError("No inputs were captured.")
+
+    def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+        self._check_captured()
+        assert self.info is not None  # missed by type checking
+        return self.info.infer_dynamic_shapes()
+
+    def infer_arguments(
+        self, index: int | None = None
+    ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        self._check_captured()
+        assert self.info is not None  # missed by type checking
+        return self.info.infer_arguments(index=index)
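
The new module above (`onnx_diagnostic/investigate/input_observer.py`) wraps a model's `forward`, records a few real calls, and infers `dynamic_shapes` and representative arguments for `torch.export.export`. A minimal usage sketch based only on the code shown above (the toy model and shapes are illustrative, not taken from the package documentation):

```python
import torch
from onnx_diagnostic.investigate.input_observer import InputObserver


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
observer = InputObserver(store_n_calls=3)
with observer(model):
    # Call the model with varying batch sizes so that dimension 0
    # is detected as dynamic: infer_dynamic_dimensions([(2, 4), (3, 4)]) -> [0].
    model(torch.randn(2, 4), torch.randn(2, 4))
    model(torch.randn(3, 4), torch.randn(3, 4))

dynamic_shapes = observer.infer_dynamic_shapes()
# -> ({0: torch.export.Dim.DYNAMIC}, {0: torch.export.Dim.DYNAMIC})
args = observer.infer_arguments()  # one captured set of positional inputs
ep = torch.export.export(model, args, dynamic_shapes=dynamic_shapes)
```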

@@ -34,7 +34,6 @@ from ..helpers.torch_helper import to_tensor
 from .report_results_comparison import ReportResultComparison
 from .evaluator import ExtendedReferenceEvaluator
 
-
 PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
 Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
 

@@ -1,5 +1,4 @@
-from typing import Any, Dict, List, Tuple, Union
-
+from typing import Any, Dict, List, Set, Tuple, Union
 
 ReportKeyNameType = Union[str, Tuple[str, int, str]]
 ReportKeyValueType = Tuple[int, Tuple[int, ...]]

@@ -14,6 +13,7 @@ class ReportResultComparison:
     :param tensors: tensor
     """
 
+    # pyrefly: ignore[unknown-name]
     def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]):  # noqa: F821
         from ..helpers.onnx_helper import dtype_to_tensor_dtype
         from ..helpers import max_diff, string_type

@@ -25,7 +25,9 @@ class ReportResultComparison:
         self.max_diff = max_diff
         self.tensors = tensors
         self._build_mapping()
+        self.unique_run_names: Set[str] = set()
 
+    # pyrefly: ignore[unknown-name]
     def key(self, tensor: "torch.Tensor") -> ReportKeyValueType:  # noqa: F821
         "Returns a key for a tensor, (onnx dtype, shape)."
         return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
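
For context, `key` buckets tensors by an `(onnx dtype, shape)` pair before comparison; a quick illustration using the standard ONNX dtype codes (not taken from the package tests):

```python
import torch
from onnx import TensorProto

t = torch.zeros(2, 3, dtype=torch.float32)
# key(t) == (dtype_to_tensor_dtype(torch.float32), (2, 3))
#        == (TensorProto.FLOAT, (2, 3)) == (1, (2, 3))
assert TensorProto.FLOAT == 1
```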

@@ -59,12 +61,15 @@ class ReportResultComparison:
         for k, v in self.value.items():
             (i_run, run_name), ref_name = k
             d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
+            # pyrefly: ignore[no-matching-overload]
             d.update(v)
             rows.append(d)
         return rows
 
     def report(
-        self,
+        self,
+        # pyrefly: ignore[unknown-name]
+        outputs: Dict[str, "torch.Tensor"],  # noqa: F821
     ) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
         """
         For every tensor in outputs, compares it to every tensor held by

@@ -79,6 +84,7 @@ class ReportResultComparison:
             key = self.key(tensor)
             if key not in self.mapping:
                 continue
+            # pyrefly: ignore[unknown-name]
            cache: Dict["torch.device", "torch.Tensor"] = {}  # noqa: F821, UP037
             for held_key in self.mapping[key]:
                 t2 = self.tensors[held_key]

@@ -63,7 +63,7 @@ class TorchOnnxEvaluator:
     * `functions`: local functions
 
     The class is not multithreaded. `runtime_info` gets updated
-    by the
+    by the class. The list of available kernels is returned by function
     :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
     Example:
 

@@ -494,8 +494,10 @@ class TorchOnnxEvaluator:
             r = self.runtime_info[k]
             r.set_value(
                 torch_ops.OpRunTensor(
+                    # pyrefly: ignore[missing-attribute]
                     v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                     is_constant=False,
+                    # pyrefly: ignore[missing-attribute]
                     may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                 )
             )

@@ -524,6 +526,7 @@ class TorchOnnxEvaluator:
                 f"for kernel {type(kernel)}."
             )
             for name, t in zip(kernel.output, res):
+                # pyrefly: ignore[bad-argument-type]
                 self.runtime_info[name].set_value(t)
             if self.verbose:
                 for name in kernel.output:

@@ -644,6 +647,7 @@ class TorchOnnxEvaluator:
                 f"for kernel {type(kernel)}."
            )
            for name, t in zip(kernel.output, res):
+                # pyrefly: ignore[bad-argument-type]
                self.runtime_info[name].set_value(t)
        else:
            assert isinstance(

@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional, Union, Tuple
 import onnx
 import torch
-from ...
+from ...typing import TensorLike
 from ...helpers import string_type
 from ...helpers.torch_helper import to_tensor
 

@@ -149,7 +149,7 @@ class OpRunSequence(OpRunValue):
     ) -> "OpRunSequence":
         "Inserts a value at a given position."
         assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
-        new_seq = OpRunSequence()
+        new_seq = OpRunSequence()  # type: ignore[abstract]
         seq = self.sequence.copy()
         new_seq.sequence = seq
         if position is None:

@@ -314,9 +314,7 @@ class OpRunKernel:
 
 
 class OpRunFunction(OpRunKernel):
-    """
-    Defines a kernel based on a local functions.
-    """
+    """Defines a kernel based on a local functions."""
 
     def __init__(
         self,

@@ -7,10 +7,10 @@ import transformers
 
 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
-    device: Optional[
+    device: Optional[torch.device] = None,
     seq_len: Optional[int] = None,
     **rope_kwargs,
-) -> Tuple[
+) -> Tuple[torch.Tensor, float]:
     """
     manual patch:
     ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``

@@ -1,6 +1,7 @@
 import itertools
 from typing import Any, Callable, List, Set, Tuple
 import torch
+import transformers.cache_utils
 from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 
 try:

@@ -22,22 +23,43 @@ from transformers.modeling_outputs import BaseModelOutput
 from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue
 from . import make_serialization_function_for_dataclass
 
-
 SUPPORTED_DATACLASSES: Set[type] = set()
 WRONG_REGISTRATIONS = {
     DynamicCache: "4.50",
     BaseModelOutput: None,
 }
+SHORTEN_LAYER_NAMES = {
+    "DynamicLayer": "D",
+    "DynamicSlidingWindowLayer": "W",
+    "StaticLayer": "S",
+    "StaticSlidingWindowLayer": "X",
+    "D": "DynamicLayer",
+    "W": "DynamicSlidingWindowLayer",
+    "S": "StaticLayer",
+    "X": "StaticSlidingWindowLayer",
+}
 
 
 def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
     ca = CacheKeyValue(cache)
     flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
-
-
-
+    unique = set(ca.cls_layers) if ca.cls_layers else None
+    if (
+        cache.__class__.__name__ != "DynamicCache"
+        or unique is None
+        or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
+    ):
+        keys = list(
+            itertools.chain.from_iterable(
+                (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+            )
         )
-
+        return flat, keys
+
+    keys = []
+    for i in range(len(ca.key_cache)):
+        letter = SHORTEN_LAYER_NAMES[ca.cls_layers[i].__name__]
+        keys.extend([f"key_{letter}{i}", f"value_{letter}{i}"])
     return flat, keys
 
 
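
`SHORTEN_LAYER_NAMES` is deliberately bidirectional: flattening encodes each cache layer class as a one-letter prefix inside the context keys, and unflattening maps the letter back to the class name. A small illustration of the naming scheme implied by the code above:

```python
SHORTEN_LAYER_NAMES = {
    "DynamicLayer": "D", "DynamicSlidingWindowLayer": "W",
    "StaticLayer": "S", "StaticSlidingWindowLayer": "X",
    "D": "DynamicLayer", "W": "DynamicSlidingWindowLayer",
    "S": "StaticLayer", "X": "StaticSlidingWindowLayer",
}

# A plain DynamicCache keeps the legacy names: key_0, value_0, key_1, ...
# Mixed layer types embed the class letter: key_D0, value_D0, key_W1, ...
name = "key_W1"
letter = name.split("_")[1][0]           # "W"
layer_cls = SHORTEN_LAYER_NAMES[letter]  # "DynamicSlidingWindowLayer"
```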

@@ -55,7 +77,20 @@ def _unflatten_cache(
     output_type=None,
 ) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-
+    expected = list(
+        itertools.chain.from_iterable(
+            (f"key_{i}", f"value_{i}") for i in range(len(values) // 2)
+        )
+    )
+    if expected == context:
+        res = make_cache(list(zip(values[::2], values[1::2])))
+    else:
+        cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
+        cls_layers = [
+            getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names
+        ]
+        res = make_cache(list(zip(values[::2], values[1::2])), cls_layers=cls_layers)
+
     assert output_type is None or isinstance(
         res, output_type
     ), f"Type mismatch between {output_type} (expected) and {type(res)}"

@@ -71,14 +106,6 @@ def flatten_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    assert (
-        not hasattr(dynamic_cache, "layers")
-        or not dynamic_cache.layers
-        or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
-    ), (
-        f"The serialization does not work yet on other layers "
-        f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
-    )
     return _flatten_key_value_cache(dynamic_cache)
 
 

@@ -86,14 +113,6 @@ def flatten_with_keys_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    assert (
-        not hasattr(dynamic_cache, "layers")
-        or not dynamic_cache.layers
-        or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
-    ), (
-        f"The serialization does not work yet on other layers "
-        f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
-    )
     return _flatten_with_keys_cache(dynamic_cache)
 
 

@@ -161,7 +180,9 @@ def unflatten_static_cache(
 ) -> StaticCache:
     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
     return _unflatten_cache(
-        lambda *args: make_static_cache(
+        lambda *args, **kwargs: make_static_cache(
+            *args, max_cache_len=values[0].shape[2], **kwargs
+        ),
         values,
         context,
         output_type=output_type,
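
The patched lambda now forwards `**kwargs` and pins `max_cache_len` to `values[0].shape[2]`. Assuming the usual `StaticCache` tensor layout `(batch, num_heads, max_cache_len, head_dim)` (an assumption about the stored tensors, not stated in the diff), that third dimension is the cache capacity:

```python
import torch

# Hypothetical first stored key tensor, (batch, num_heads, max_cache_len, head_dim).
values = [torch.zeros(1, 8, 128, 64)]
max_cache_len = values[0].shape[2]  # -> 128
```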

@@ -8,11 +8,9 @@ from .hghub.model_inputs import _preprocess_model_id
 from .hghub import get_untrained_model_with_inputs
 from .validate import filter_inputs, make_patch_kwargs
 
-
 CODE_SAMPLES = {
     "imports": "from typing import Any\nimport torch",
-    "get_model_with_inputs": textwrap.dedent(
-        """
+    "get_model_with_inputs": textwrap.dedent("""
 def get_model_with_inputs(
     model_id:str,
     subfolder: str | None = None,

@@ -57,8 +55,7 @@ CODE_SAMPLES = {
     if device:
         data["model"] = data["model"].to(device)
     return data["model"]
-"""
-    ),
+"""),
 }
 
 

@@ -198,7 +195,7 @@ def code_sample(
         this is not always possible
     :param use_pretrained: use the trained model, not the untrained one
     :param optimization: optimization to apply to the exported model,
-        depend on the
+        depend on the exporter
     :param quiet: if quiet, catches exception if any issue
     :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
         if True before exporting

@@ -326,11 +323,9 @@ def code_sample(
         imports,
         cache_import,
         CODE_SAMPLES["get_model_with_inputs"],
-        textwrap.dedent(
-            f"""
+        textwrap.dedent(f"""
 model = get_model_with_inputs({model_args})
-"""
-        ),
+"""),
         f"inputs = {input_code}",
         exporter_code,
     ]
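
The `CODE_SAMPLES` edits only move the opening and closing `"""` onto the `textwrap.dedent(...)` lines; the literal still begins with a newline, so the produced string is unchanged. A quick equivalence check:

```python
import textwrap

a = textwrap.dedent(
    """
x = 1
"""
)
b = textwrap.dedent("""
x = 1
""")
assert a == b == "\nx = 1\n"
```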