onnx-diagnostic 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/ci_models/export_phi4_mm.py +1 -1
- onnx_diagnostic/doc.py +258 -8
- onnx_diagnostic/export/api.py +755 -5
- onnx_diagnostic/export/dynamic_shapes.py +61 -4
- onnx_diagnostic/export/shape_helper.py +1 -8
- onnx_diagnostic/helpers/cache_helper.py +98 -21
- onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
- onnx_diagnostic/helpers/helper.py +36 -6
- onnx_diagnostic/helpers/onnx_helper.py +7 -0
- onnx_diagnostic/helpers/ort_session.py +5 -0
- onnx_diagnostic/helpers/rt_helper.py +14 -1
- onnx_diagnostic/helpers/torch_helper.py +22 -9
- onnx_diagnostic/tasks/image_text_to_text.py +8 -5
- onnx_diagnostic/tasks/text_generation.py +17 -17
- onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
- onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
- onnx_diagnostic/torch_models/validate.py +48 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +3 -1
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +28 -28
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0

onnx_diagnostic/export/dynamic_shapes.py

@@ -329,7 +329,7 @@ class CoupleInputsDynamicShapes:
         if type(inputs) in (tuple, list, dict):
             # Type must be strict, some custom classes can inherit from those.
             assert type(inputs) is type(ds), (
-                f"Input type and dynamic
+                f"Input type and dynamic shapes type mush match but "
                 f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
                 f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
             )
@@ -352,6 +352,19 @@ class CoupleInputsDynamicShapes:
                 else None
             )
         assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
+        if set(inputs) != set(ds):
+            not_in_ds = {k for k in inputs if k not in ds}
+            not_in_inputs = {k for k in ds if k not in inputs}
+            assert not_in_inputs == {"kwargs"} and set(ds["kwargs"]) == not_in_ds, (
+                f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
+                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}, "
+                f"not_in_ds={not_in_ds}, not_in_inputs={not_in_inputs}"
+            )
+            # Tweak...
+            kws = ds["kwargs"]
+            del ds["kwargs"]
+            ds.update(kws)
+
         assert set(inputs) == set(ds), (
             f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
             f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
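
A minimal sketch of the dictionary tweak performed by the new branch above, with hypothetical keys (the real method first asserts the mismatch is exactly a nested "kwargs" entry before flattening it):

    inputs = {"input_ids": None, "attention_mask": None}
    ds = {"input_ids": {0: "batch"}, "kwargs": {"attention_mask": {0: "batch"}}}
    # flatten the nested keyword shapes so the keys of ds match the keys of inputs
    kws = ds["kwargs"]
    del ds["kwargs"]
    ds.update(kws)
    assert set(ds) == set(inputs)
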
@@ -366,13 +379,15 @@ class CoupleInputsDynamicShapes:
             return dvalue if dvalue else None

         # A custom class.
-        assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
+        assert inputs is None or inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
             f"Class {inputs.__class__.__name__!r} was not registered using "
             f"torch.utils._pytree.register_pytree_node, it is not possible to "
             f"map this class with the given dynamic shapes."
         )
         if flatten_unflatten:
             flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
+            if isinstance(flatunflat, (list, tuple, dict)) and len(flatunflat) == 0:
+                return flatunflat
             res = cls._generic_walker_step(
                 processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
             )
@@ -667,6 +682,11 @@ class ModelInputs:
             if self.signature
             else None
         )
+        self.forward_parameters_kinds = (
+            {p.name: p.kind for p in self.signature.parameters.values()}
+            if self.signature
+            else None
+        )
         self.forward_ordered_parameter_names = (
             list(self.signature.parameters) if self.signature else None
         )
@@ -947,6 +967,8 @@ class ModelInputs:
         """
         Guesses the dynamic shapes for that module from two execution.
         If there is only one execution, then that would be static dimensions.
+        If the model signature is available, the kwargs are reordered following
+        the signature order, otherwise it follows the order given in the inputs.

         :param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
             dimension if the number of inputs is one,
@@ -973,7 +995,13 @@ class ModelInputs:
             len(s1) == 1
         ), f"Different numbers of positional arguments {s1} for {self.full_name}"
         s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
-        assert len(s2)
+        assert len(s2) > 0, f"empty {s2} for {self.full_name}"
+        if len(s2) > 1:
+            # We need to keep the largest set of inputs, the one including all the others.
+            sum_s2 = set()
+            for s in s2:
+                sum_s2 |= set(s)
+            s2 = {tuple(sum_s2)}
         args = []
         kwargs = {}
         for i in range(s1.pop()):
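
A short sketch of the union step introduced above: when the recorded executions did not all use the same keyword arguments, the shapes are guessed for the union of the keyword names (the names here are hypothetical):

    s2 = {("input_ids",), ("input_ids", "attention_mask")}
    if len(s2) > 1:
        sum_s2 = set()
        for s in s2:
            sum_s2 |= set(s)
        s2 = {tuple(sum_s2)}
    # s2 now holds a single tuple containing both names
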
@@ -993,12 +1021,31 @@ class ModelInputs:
                 f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
             )

-            objs = [_[1][name] for _ in self.inputs]
+            objs = [_[1][name] for _ in self.inputs if name in _[1]]
             kwargs[name] = self.guess_dynamic_shape_object(
                 *objs,
                 auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
                 msg=lambda name=name: f" failing input {name!r}",
             )
+        # reordering
+        if kwargs:
+            if self.forward_ordered_parameter_names:
+                kwargs1 = {
+                    p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+                }
+                kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+            else:
+                # We reorder the same the way the input were given.
+                use = None
+                params = set(kwargs)
+                for _args, kws in self.inputs:
+                    if set(kws) == params:
+                        use = kws
+                        break
+                if use:
+                    ordered = list(use)
+                    kwargs = {k: kwargs[k] for k in ordered}
+
         return tuple(args), kwargs

     def move_to_kwargs(
@@ -1061,6 +1108,16 @@ class ModelInputs:
             f"and kwargs={set(kwargs)}, "
             f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
         )
+        if kwargs is not None and self.forward_ordered_parameter_names:
+            kwargs1 = {
+                p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+            }
+            kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+        if kw_dyn is not None and self.forward_ordered_parameter_names:
+            kw_dyn1 = {
+                p: kw_dyn[p] for p in self.forward_ordered_parameter_names if p in kw_dyn
+            }
+            kw_dyn = {**kw_dyn1, **{k: v for k, v in kw_dyn.items() if k not in kw_dyn1}}
         return args, kwargs, (tuple(), kw_dyn)

     def validate_inputs_for_export(
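
Both hunks above rely on the same reordering idiom: keys that appear in the forward signature come first, in signature order, and anything else keeps its original position at the end. A standalone sketch (the helper and the forward function are hypothetical):

    import inspect

    def reorder_by_signature(fn, kwargs):
        # keys known to the signature first, in signature order; the rest last
        names = list(inspect.signature(fn).parameters)
        first = {p: kwargs[p] for p in names if p in kwargs}
        return {**first, **{k: v for k, v in kwargs.items() if k not in first}}

    def forward(input_ids, attention_mask=None, past_key_values=None):
        return input_ids

    print(list(reorder_by_signature(forward, {"past_key_values": 1, "input_ids": 0})))
    # ['input_ids', 'past_key_values']
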

onnx_diagnostic/export/shape_helper.py

@@ -47,7 +47,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
             make_dynamic_cache,
             make_encoder_decoder_cache,
             make_mamba_cache,
-            make_sliding_window_cache,
             make_static_cache,
         )
         from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
@@ -77,13 +76,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
                     ]
                 ),
             ),
-            make_sliding_window_cache(
-                [
-                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-                ]
-            ),
             make_static_cache(
                 [
                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
@@ -210,6 +202,7 @@ def make_fake_with_dynamic_dimensions(
     This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
     Parameter ``existing`` is used to reused the same object when the dynamic
     dimension is given the same name as another one.
+    This function works with caches only if ``transformers>=4.57``.

     A simple tensor:


onnx_diagnostic/helpers/cache_helper.py

@@ -19,7 +19,7 @@ class CacheKeyValue:
         capi.value_cache
     """

-    def __init__(self, cache=None):
+    def __init__(self, cache=None, cls_layers=None):
         if hasattr(cache, "layers"):
             layers = [
                 layer
@@ -28,32 +28,52 @@ class CacheKeyValue:
             ]
             self.key_cache = [layer.keys for layer in layers]
             self.value_cache = [layer.values for layer in layers]
-
-
-
-
-                f"issue with key_cache={string_type(self.key_cache)}, "
-                f"or value_cache={string_type(self.value_cache)}, "
-                f"cache.layers={string_type(cache.layers)}"
-            )
+            assert (
+                cls_layers is None
+            ), f"cache is {type(cache)}, cannot specify cls_layers={cls_layers}"
+            self.cls_layers = [type(lay) for lay in cache.layers]
         elif cache is not None and hasattr(cache, "key_cache"):
             self.key_cache = cache.key_cache
             self.value_cache = cache.value_cache
+            self.cls_layers = cls_layers
+        elif (
+            cache is not None
+            and isinstance(cache, list)
+            and all(isinstance(t, torch.Tensor) for t in cache)
+        ):
+            self.key_cache = cache[::2]
+            self.value_cache = cache[1::2]
+            self.cls_layers = cls_layers
         elif cache is None:
             self.key_cache = None
             self.value_cache = None
+            self.cls_layers = cls_layers
         else:
             raise NotImplementedError(f"type(cache)={type(cache)}")

     def make_dynamic_cache(self):
         """Does the reverse operation."""
-        return make_dynamic_cache(
+        return make_dynamic_cache(
+            list(zip(self.key_cache, self.value_cache)), cls_layers=self.cls_layers
+        )

     @property
     def n_layers(self) -> int:
         """Returns the number of layers."""
         return len(self.key_cache) if self.key_cache else 0

+    def __len__(self) -> int:
+        "Returns the number of tensors."
+        return len(self.key_cache) + len(self.value_cache)
+
+    def aslist(self) -> List[torch.Tensor]:
+        "Returns tensors in a list."
+        res = []
+        for i in range(self.n_layers):
+            res.append(self.key_cache[i])
+            res.append(self.value_cache[i])
+        return res
+

 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
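
A minimal usage sketch for the extended CacheKeyValue wrapper (it assumes torch and a recent transformers are installed; the shapes are made up):

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    cache = make_dynamic_cache(
        [(torch.rand(2, 4, 3, 8), torch.rand(2, 4, 3, 8)) for _ in range(2)]
    )
    capi = CacheKeyValue(cache)
    print(capi.n_layers, len(capi))   # 2 layers, 4 tensors
    flat = capi.aslist()              # [key_0, value_0, key_1, value_1]
    restored = capi.make_dynamic_cache()
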
@@ -164,12 +184,16 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        cls_layers: Optional[Union[str, List[type]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
         This version is valid for ``transformers >= 4.50``.

         :param key_value_pairs: list of pairs of (key, values)
+        :param cls_layers: to select the appropriate class to use on each layer,
+            if specified, sliding_window is ignored, it can be a string
+            if all layers are expected to follow the same class
         :return: :class:`transformers.cache_utils.DynamicCache`

         Example:
@@ -200,15 +224,49 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
         are supported.
         """
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+        cls_kwargs = {}
+        if isinstance(cls_layers, str):
+            assert hasattr(
+                transformers.cache_utils, cls_layers
+            ), f"Unable to find class {cls_layers!r} in transformers.cache_utils"
+            cls_layer = getattr(transformers.cache_utils, cls_layers)
+            if cls_layers == "DynamicSlidingWindowLayer":
+                cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                assert isinstance(
+                    cls_kwargs["sliding_window"], int
+                ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+        elif cls_layers is not None:
+            unique = set(cls_layers)
+            assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}"
+            cls_layer = unique.pop()
+            if (
+                hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer")
+                and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer
+            ):
+                from .helper import string_type
+
+                assert key_value_pairs and key_value_pairs[0], (
+                    f"not implemented for key_value_pairs="
+                    f"{string_type(key_value_pairs, with_shape=True)}"
+                )
+                cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                assert isinstance(
+                    cls_kwargs["sliding_window"], int
+                ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+        else:
+            cls_layer = (
+                transformers.cache_utils.DynamicLayer
+                if hasattr(transformers.cache_utils, "DynamicLayer")
+                else None
+            )
+
         if (
             key_value_pairs
             and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
             and pv.Version(transformers.__version__) >= pv.Version("4.56")
         ):
             cache = transformers.cache_utils.DynamicCache()
-            cache.layers.extend(
-                [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
-            )
+            cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
             for i, layer in enumerate(cache.layers):
                 k, v = key_value_pairs[i][0], key_value_pairs[i][1]
                 layer.dtype = k.dtype
@@ -222,14 +280,21 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             )
             return finalize_cache(cache)

-        cache = transformers.cache_utils.DynamicCache(
-        if hasattr(cache, "layers") and
-
-
-
-
-
-        cache
+        cache = transformers.cache_utils.DynamicCache()
+        if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer:
+            cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
+            for i, layer in enumerate(cache.layers):
+                layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
+                layer.is_initialized = True
+        else:
+            cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+            if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+                # The cache constructor contains the two following lines
+                # (in cache_utils.py) which append empty layers when the cache is
+                # initialized. We need to remove them.
+                # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+                # self.append_new_layers(self.num_hidden_layers - 1)
+                cache.layers[:] = cache.layers[-len(key_value_pairs) :]
         assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
            f"{len(key_value_pairs)} expected."
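
A hedged sketch of the new cls_layers argument based on the signature above: the layer class is looked up by name in transformers.cache_utils, so "DynamicSlidingWindowLayer" only works with a transformers release that defines it, and the window size is taken from dimension 2 of the first key tensor.

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    pairs = [(torch.rand(2, 4, 6, 8), torch.rand(2, 4, 6, 8)) for _ in range(2)]
    cache = make_dynamic_cache(pairs)  # plain DynamicLayer layers
    sliding = make_dynamic_cache(pairs, cls_layers="DynamicSlidingWindowLayer")
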
@@ -240,6 +305,7 @@ else:

     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+        cls_layers: Optional[Union[str, List[type]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -271,6 +337,7 @@ else:
             )
             print(string_type(past_key_values, with_shape=True))
         """
+        assert not cls_layers, "cls_layers cannot be used for transformers<5."
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
         cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
         for i, (key, value) in enumerate(key_value_pairs):
@@ -516,9 +583,13 @@ if hasattr(transformers.cache_utils, "SlidingWindowCache"):
         )
         return finalize_cache(cache)

+    def get_make_hybrid_cache():
+        return make_sliding_window_cache
+
 else:
     make_sliding_window_cache = None  # type: ignore[assignment]

+
 if hasattr(transformers.cache_utils, "HybridCache"):

     def make_hybrid_cache(
@@ -680,9 +751,15 @@ if hasattr(transformers.cache_utils, "HybridCache"):
         )
         return finalize_cache(cache)

+    def get_make_hybrid_cache():
+        return make_hybrid_cache
+
 else:
     make_hybrid_cache = None  # type: ignore[assignment]

+    def get_make_hybrid_cache():
+        return None
+

 def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
     """
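
The two hunks above follow the module's feature-detection pattern: the get_make_hybrid_cache that ends up exported is whichever definition matches the installed transformers. A hedged sketch of how a caller might use it, assuming the function is importable from the same module:

    from onnx_diagnostic.helpers.cache_helper import get_make_hybrid_cache

    make_cache = get_make_hybrid_cache()
    if make_cache is None:
        print("this transformers build has no hybrid/sliding-window cache class")
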

onnx_diagnostic/helpers/fake_tensor_helper.py

@@ -105,6 +105,8 @@ class FakeTensorContext:
         reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
             axis=tuple(sorted(sh)), keepdim=True
         )
+        if len(reduced_tensor.shape) == 0 == len(new_shape):
+            return reduced_tensor
         return reduced_tensor.expand(*new_shape)

     def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
@@ -144,19 +146,22 @@ class FakeTensorContext:
         """
         See
         :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
+        If caches are used, it requires ``transformers>=4.57``.
         """
         if x is None:
             return None, None
-        if
+        if type(x) in (list, tuple):
             return x.__class__(
                 [
                     self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
                     for i, ds in zip(x, dynamic_shapes)
                 ]
             )
-        if
+        if type(x) is dict:
             return {
-            k: self.make_fake_with_dynamic_dimensions(
+                k: self.make_fake_with_dynamic_dimensions(
+                    v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None
+                )
                 for k, v in x.items()
             }
         if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
@@ -187,6 +192,17 @@ class FakeTensorContext:
                 x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
             )
             return x
+        if x.__class__.__name__ == "BaseModelOutput":
+            assert (
+                list(x.keys()) == ["last_hidden_state"] and x.last_hidden_state is not None
+            ), (
+                f"Field 'last_hidden_state' is empty for {type(x)} or other fields "
+                f"{list(x.keys())} are used."
+            )
+            x.last_hidden_state = self.make_fake_with_dynamic_dimensions(
+                x.last_hidden_state, dynamic_shapes=dynamic_shapes[0]
+            )
+            return x
         if hasattr(x, "shape"):
             assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
                 f"dynamic_shapes must be a dictionary at this stage but "
@@ -197,9 +213,11 @@ class FakeTensorContext:
             for idim, dim in enumerate(x.shape):
                 if dynamic_shapes is not None and idim in dynamic_shapes:
                     s = dynamic_shapes[idim]
+                    if s.__class__.__name__ == "Dim":
+                        s = s.__name__
                     assert isinstance(s, str), (
                         f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
-                        f"at index {idim}"
+                        f"at index {idim}, self._mapping_str={self._mapping_str}"
                     )
                     if s in self._mapping_str:
                         dim = self._mapping_str[s]
@@ -217,10 +235,13 @@ class FakeTensorContext:

             x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)

-            t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
+            t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x  # type: ignore[arg-type]
             assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
             assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
             return t
+        if isinstance(x, (int, bool, float)):
+            # It is a constant, we don't change that.
+            return x
         from ..helpers import string_type

         raise TypeError(

onnx_diagnostic/helpers/helper.py

@@ -1,7 +1,6 @@
 import ast
 import enum
 import inspect
-import itertools
 import json
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -704,9 +703,35 @@ def string_type(
     if obj.__class__.__name__ == "VirtualTensor":
         if verbose:
             print(f"[string_type] TT4:{type(obj)}")
+
+        def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]:  # noqa: F821
+            if isinstance(value, str):
+                return value
+            if hasattr(value, "node") and isinstance(value.node, str):
+                return f"{value.node}"
+
+            from torch.fx.experimental.sym_node import SymNode
+
+            if hasattr(value, "node") and isinstance(value.node, SymNode):
+                # '_expr' is safer than expr
+                return str(value.node._expr).replace(" ", "")
+
+            try:
+                val_int = int(value)
+                return val_int
+            except (
+                TypeError,
+                ValueError,
+                AttributeError,
+                torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
+            ):
+                pass
+
+            raise AssertionError(f"Unable to convert {value!r} into string")
+
         return (
             f"{obj.__class__.__name__}(name={obj.name!r}, "
-            f"dtype={obj.dtype}, shape={obj.shape})"
+            f"dtype={obj.dtype}, shape={tuple(_torch_sym_int_to_str(_) for _ in obj.shape)})"
         )

     if obj.__class__.__name__ == "KeyValuesWrapper":
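
A standalone sketch of the conversion the new nested helper performs for symbolic dimensions in a VirtualTensor shape (it mirrors the added code: plain ints pass through, torch.SymInt values are rendered from their SymNode expression):

    import torch
    from torch.fx.experimental.sym_node import SymNode

    def sym_int_to_str(value):
        if isinstance(value, str):
            return value
        if hasattr(value, "node") and isinstance(value.node, SymNode):
            return str(value.node._expr).replace(" ", "")
        return int(value)

    print(sym_int_to_str(4))  # 4
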
@@ -775,6 +800,9 @@ def string_type(
             print(f"[string_type] TT8:{type(obj)}")
         return repr(obj).replace(" ", "").replace("\n", " ")

+    if isinstance(obj, torch.fx.proxy.Proxy):
+        return repr(obj)
+
     if ignore:
         if verbose:
             print(f"[string_type] CACHE4:{type(obj)}")
@@ -962,15 +990,17 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
     if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
         from .cache_helper import CacheKeyValue

-
-        return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache)))
+        return CacheKeyValue(x).aslist()

     if x.__class__.__name__ == "EncoderDecoderCache":
-        res =
+        res = [
+            *flatten_object(x.self_attention_cache),
+            *flatten_object(x.cross_attention_cache),
+        ]
         return tuple(res)
     if x.__class__.__name__ == "MambaCache":
         if isinstance(x.conv_states, list):
-            res = flatten_object(x.conv_states)
+            res = [*flatten_object(x.conv_states), *flatten_object(x.ssm_states)]
             return tuple(res)
         return (x.conv_states, x.ssm_states)
     if hasattr(x, "to_tuple"):

onnx_diagnostic/helpers/onnx_helper.py

@@ -28,6 +28,7 @@ from onnx import (
     NodeProto,
     OperatorSetIdProto,
     TensorProto,
+    TypeProto,
     ValueInfoProto,
     load as onnx_load,
 )
@@ -385,6 +386,12 @@ def pretty_onnx(
         shape_str = ",".join(map(str, shape))
         return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"

+    if isinstance(onx, TypeProto):
+        itype = onx.tensor_type.elem_type
+        shape = tuple((d.dim_param or d.dim_value) for d in onx.tensor_type.shape.dim)
+        shape_str = ",".join(map(str, shape))
+        return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}]"
+
     if isinstance(onx, AttributeProto):
         att = onx
         if att.type == AttributeProto.INT:
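
A hedged sketch of the new TypeProto branch, building a tensor type with onnx.helper and printing it (the exact rendering depends on onnx_dtype_name):

    from onnx import TensorProto
    from onnx.helper import make_tensor_type_proto
    from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

    tp = make_tensor_type_proto(TensorProto.FLOAT, ["batch", 8])
    print(pretty_onnx(tp))  # something like FLOAT[batch,8]
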

onnx_diagnostic/helpers/ort_session.py

@@ -1,3 +1,4 @@
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import numpy as np
@@ -76,6 +77,10 @@ class _InferenceSession:
         session_options.enable_profiling = enable_profiling
         if optimized_model_filepath:
             session_options.optimized_model_filepath = optimized_model_filepath
+            session_options.add_session_config_entry(
+                "session.optimized_model_external_initializers_file_name",
+                f"{os.path.splitext(os.path.split(optimized_model_filepath)[-1])[0]}.data",
+            )
         if log_severity_level is not None:
             session_options.log_severity_level = log_severity_level
         if log_verbosity_level is not None:
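
What the added lines configure, shown with plain onnxruntime (the file names here are hypothetical): when the optimized model is written to disk, large initializers are redirected to a side .data file named after it.

    import onnxruntime

    so = onnxruntime.SessionOptions()
    so.optimized_model_filepath = "model.optimized.onnx"
    so.add_session_config_entry(
        "session.optimized_model_external_initializers_file_name",
        "model.optimized.data",
    )
    # sess = onnxruntime.InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])
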

onnx_diagnostic/helpers/rt_helper.py

@@ -41,7 +41,20 @@ def make_feeds(
     """
     # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
     # because it's fued into rotary embedding in GQA.
-    if is_modelbuilder and isinstance(inputs, dict):
+    if is_modelbuilder and isinstance(inputs, dict) and "position_ids" in inputs:
+        position_ids = inputs["position_ids"]  # type: ignore[valid-type]
+        # We just check position_ids are contiguous.
+        assert isinstance(position_ids, torch.Tensor) and (
+            (
+                (position_ids - position_ids.min())
+                == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)
+            )
+            .max()
+            .item()
+        ), (
+            f"ModelBuilder does not support position_ids={position_ids}, "
+            f"inputs={string_type(inputs, with_shape=True)}"
+        )
         inputs.pop("position_ids", None)  # Ensure 'position_ids' absent before removing.

     flat = flatten_object(inputs, drop_keys=True)
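
The intent of the check above is that position_ids must be a (possibly shifted) arange along the last axis before it can be dropped for ModelBuilder. A small sketch of that intent:

    import torch

    position_ids = torch.arange(3, 8).unsqueeze(0)  # contiguous positions starting at 3
    expected = torch.arange(position_ids.shape[-1]).unsqueeze(0)
    assert bool(((position_ids - position_ids.min()) == expected).all())
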

onnx_diagnostic/helpers/torch_helper.py

@@ -15,9 +15,6 @@ from .helper import string_type, size_type
 from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
-    make_hybrid_cache,
-    make_sliding_window_cache,
-    make_mamba_cache,
     make_static_cache,
     CacheKeyValue,
 )
@@ -769,10 +766,22 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return {to_any(t, to_value) for t in value}
     if type(value) is dict:
         return {k: to_any(t, to_value) for k, t in value.items()}
-    if value.__class__.__name__
-    make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
+    if value.__class__.__name__ == "DynamicCache":
         cc = CacheKeyValue(value)
-        return
+        return make_dynamic_cache(
+            list(
+                zip(
+                    [t.to(to_value) if t is not None else t for t in cc.key_cache],
+                    [t.to(to_value) if t is not None else t for t in cc.value_cache],
+                )
+            ),
+            cls_layers=cc.cls_layers,
+        )
+    if value.__class__.__name__ == "HybridCache":
+        from .cache_helper import make_hybrid_cache
+
+        cc = CacheKeyValue(value)
+        return make_hybrid_cache(
             list(
                 zip(
                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
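
A hedged usage sketch of to_any on a cache (module paths taken from the file list above): every tensor in a DynamicCache is converted and, per the hunk above, the layer classes are carried over through cls_layers.

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.helpers.torch_helper import to_any

    cache = make_dynamic_cache([(torch.rand(2, 4, 3, 8), torch.rand(2, 4, 3, 8))])
    half = to_any(cache, torch.float16)
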
@@ -843,7 +852,9 @@ def torch_deepcopy(value: Any) -> Any:
         from .cache_helper import CacheKeyValue

         ca = CacheKeyValue(value)
-        return make_dynamic_cache(
+        return make_dynamic_cache(
+            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
+        )
     if value.__class__.__name__ == "StaticCache":
         from .cache_helper import CacheKeyValue

@@ -858,12 +869,12 @@ def torch_deepcopy(value: Any) -> Any:
             max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
         )
     if value.__class__.__name__ == "HybridCache":
-        from .cache_helper import CacheKeyValue
+        from .cache_helper import CacheKeyValue, make_hybrid_cache

         ca = CacheKeyValue(value)
         return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
     if value.__class__.__name__ == "SlidingWindowCache":
-        from .cache_helper import CacheKeyValue
+        from .cache_helper import CacheKeyValue, make_sliding_window_cache

         ca = CacheKeyValue(value)
         return make_sliding_window_cache(
@@ -875,6 +886,8 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(value.cross_attention_cache),
         )
     if value.__class__.__name__ == "MambaCache":
+        from .cache_helper import make_mamba_cache
+
         return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))

     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES: